core[patch]: Add ruff rule UP006 (use PEP 585 annotations) (#26574)

* Added rule `UP006` now that Pydantic is v2+
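
For illustration, a minimal before/after sketch of the rewrite UP006 enforces (hypothetical function, not taken from this diff):

    # Before: deprecated typing aliases, flagged by ruff's UP006
    from typing import Dict, List

    def tags_by_doc(ids: List[str]) -> Dict[str, List[str]]: ...

    # After: PEP 585 built-in generics, as applied throughout this commit
    def tags_by_doc(ids: list[str]) -> dict[str, list[str]]: ...

The rewrite can be applied mechanically with `ruff check --select UP006 --fix`.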

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Christophe Bornet authored 2024-09-17 23:22:50 +02:00, committed by GitHub
parent 2ef4c9466f
commit 3a99467ccb
70 changed files with 1222 additions and 1286 deletions
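
Why the "now that Pydantic is v2+" note above matters: under `from __future__ import annotations` ordinary annotations stay as unevaluated strings, but Pydantic evaluates model field annotations at runtime, so built-in generics in fields need a Python/Pydantic combination that can resolve them. A minimal sketch with a hypothetical model, not part of this commit:

    from pydantic import BaseModel

    class Example(BaseModel):
        tags: list[str] = []        # resolved at runtime by Pydantic v2
        meta: dict[str, int] = {}   # replaces the old Dict[str, int]

    print(Example(tags=["a"], meta={"n": 1}))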

View File

@@ -25,7 +25,7 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
 from __future__ import annotations
 import json
-from typing import Any, List, Literal, Sequence, Union
+from typing import Any, Literal, Sequence, Union
 from langchain_core.load.serializable import Serializable
 from langchain_core.messages import (
@@ -71,7 +71,7 @@ class AgentAction(Serializable):
         return True
     @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object.
         Default is ["langchain", "schema", "agent"]."""
         return ["langchain", "schema", "agent"]
@@ -145,7 +145,7 @@ class AgentFinish(Serializable):
         return True
     @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "agent"]

View File

@@ -23,7 +23,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence
 from langchain_core.outputs import Generation
 from langchain_core.runnables import run_in_executor
@@ -157,7 +157,7 @@ class InMemoryCache(BaseCache):
         Raises:
             ValueError: If maxsize is less than or equal to 0.
         """
-        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
+        self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {}
         if maxsize is not None and maxsize <= 0:
             raise ValueError("maxsize must be greater than 0")
         self._maxsize = maxsize

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
 from uuid import UUID
 from tenacity import RetryCallState
@@ -118,7 +118,7 @@ class ChainManagerMixin:
     def on_chain_end(
         self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -222,13 +222,13 @@ class CallbackManagerMixin:
     def on_llm_start(
         self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when LLM starts running.
@@ -249,13 +249,13 @@ class CallbackManagerMixin:
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chat model starts running.
@@ -280,13 +280,13 @@ class CallbackManagerMixin:
     def on_retriever_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         query: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when the Retriever starts running.
@@ -303,13 +303,13 @@ class CallbackManagerMixin:
     def on_chain_start(
         self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chain starts running.
@@ -326,14 +326,14 @@ class CallbackManagerMixin:
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when the tool starts running.
@@ -393,8 +393,8 @@ class RunManagerMixin:
         data: Any,
         *,
         run_id: UUID,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Override to define a handler for a custom event.
@@ -470,13 +470,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_llm_start(
         self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM starts running.
@@ -497,13 +497,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chat model starts running.
@@ -533,7 +533,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on new LLM token. Only available when streaming is enabled.
@@ -554,7 +554,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM ends running.
@@ -573,7 +573,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM errors.
@@ -590,13 +590,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_chain_start(
         self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when a chain starts running.
@@ -613,11 +613,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_chain_end(
         self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when a chain ends running.
@@ -636,7 +636,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when chain errors.
@@ -651,14 +651,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when the tool starts running.
@@ -680,7 +680,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when the tool ends running.
@@ -699,7 +699,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run when tool errors.
@@ -718,7 +718,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on an arbitrary text.
@@ -754,7 +754,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on agent action.
@@ -773,7 +773,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on the agent end.
@@ -788,13 +788,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_retriever_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         query: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on the retriever start.
@@ -815,7 +815,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on the retriever end.
@@ -833,7 +833,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> None:
         """Run on retriever error.
@@ -852,8 +852,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         data: Any,
         *,
         run_id: UUID,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """Override to define a handler for a custom event.
@@ -880,14 +880,14 @@ class BaseCallbackManager(CallbackManagerMixin):
     def __init__(
         self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
         parent_run_id: Optional[UUID] = None,
         *,
-        tags: Optional[List[str]] = None,
-        inheritable_tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
     ) -> None:
         """Initialize callback manager.
@@ -901,8 +901,8 @@ class BaseCallbackManager(CallbackManagerMixin):
                 Default is None.
             metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
         """
-        self.handlers: List[BaseCallbackHandler] = handlers
-        self.inheritable_handlers: List[BaseCallbackHandler] = (
+        self.handlers: list[BaseCallbackHandler] = handlers
+        self.inheritable_handlers: list[BaseCallbackHandler] = (
             inheritable_handlers or []
         )
         self.parent_run_id: Optional[UUID] = parent_run_id
@@ -1002,7 +1002,7 @@ class BaseCallbackManager(CallbackManagerMixin):
         self.inheritable_handlers.remove(handler)
     def set_handlers(
-        self, handlers: List[BaseCallbackHandler], inherit: bool = True
+        self, handlers: list[BaseCallbackHandler], inherit: bool = True
     ) -> None:
         """Set handlers as the only handlers on the callback manager.
@@ -1024,7 +1024,7 @@ class BaseCallbackManager(CallbackManagerMixin):
         """
         self.set_handlers([handler], inherit=inherit)
-    def add_tags(self, tags: List[str], inherit: bool = True) -> None:
+    def add_tags(self, tags: list[str], inherit: bool = True) -> None:
         """Add tags to the callback manager.
         Args:
@@ -1038,7 +1038,7 @@ class BaseCallbackManager(CallbackManagerMixin):
         if inherit:
             self.inheritable_tags.extend(tags)
-    def remove_tags(self, tags: List[str]) -> None:
+    def remove_tags(self, tags: list[str]) -> None:
         """Remove tags from the callback manager.
         Args:
@@ -1048,7 +1048,7 @@ class BaseCallbackManager(CallbackManagerMixin):
             self.tags.remove(tag)
             self.inheritable_tags.remove(tag)
-    def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
+    def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
         """Add metadata to the callback manager.
         Args:
@@ -1059,7 +1059,7 @@ class BaseCallbackManager(CallbackManagerMixin):
         if inherit:
             self.inheritable_metadata.update(metadata)
-    def remove_metadata(self, keys: List[str]) -> None:
+    def remove_metadata(self, keys: list[str]) -> None:
         """Remove metadata from the callback manager.
         Args:

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
-from typing import Any, Dict, Optional, TextIO, cast
+from typing import Any, Optional, TextIO, cast
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import BaseCallbackHandler
@@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         self.file.close()
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain.
@@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             file=self.file,
         )
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain.
         Args:

View File

@@ -14,9 +14,7 @@ from typing import (
     AsyncGenerator,
     Callable,
     Coroutine,
-    Dict,
     Generator,
-    List,
     Optional,
     Sequence,
     Type,
@@ -64,12 +62,12 @@ def trace_as_chain_group(
     group_name: str,
     callback_manager: Optional[CallbackManager] = None,
     *,
-    inputs: Optional[Dict[str, Any]] = None,
+    inputs: Optional[dict[str, Any]] = None,
     project_name: Optional[str] = None,
     example_id: Optional[Union[str, UUID]] = None,
     run_id: Optional[UUID] = None,
-    tags: Optional[List[str]] = None,
-    metadata: Optional[Dict[str, Any]] = None,
+    tags: Optional[list[str]] = None,
+    metadata: Optional[dict[str, Any]] = None,
 ) -> Generator[CallbackManagerForChainGroup, None, None]:
     """Get a callback manager for a chain group in a context manager.
     Useful for grouping different calls together as a single run even if
@@ -144,12 +142,12 @@ async def atrace_as_chain_group(
     group_name: str,
     callback_manager: Optional[AsyncCallbackManager] = None,
     *,
-    inputs: Optional[Dict[str, Any]] = None,
+    inputs: Optional[dict[str, Any]] = None,
     project_name: Optional[str] = None,
    example_id: Optional[Union[str, UUID]] = None,
     run_id: Optional[UUID] = None,
-    tags: Optional[List[str]] = None,
-    metadata: Optional[Dict[str, Any]] = None,
+    tags: Optional[list[str]] = None,
+    metadata: Optional[dict[str, Any]] = None,
 ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
     """Get an async callback manager for a chain group in a context manager.
     Useful for grouping different async calls together as a single run even if
@@ -240,7 +238,7 @@ def shielded(func: Func) -> Func:
 def handle_event(
-    handlers: List[BaseCallbackHandler],
+    handlers: list[BaseCallbackHandler],
     event_name: str,
     ignore_condition_name: Optional[str],
     *args: Any,
@@ -258,10 +256,10 @@ def handle_event(
         *args: The arguments to pass to the event handler.
         **kwargs: The keyword arguments to pass to the event handler
     """
-    coros: List[Coroutine[Any, Any, Any]] = []
+    coros: list[Coroutine[Any, Any, Any]] = []
     try:
-        message_strings: Optional[List[str]] = None
+        message_strings: Optional[list[str]] = None
         for handler in handlers:
             try:
                 if ignore_condition_name is None or not getattr(
@@ -318,7 +316,7 @@ def handle_event(
         _run_coros(coros)
-def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
+def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
     if hasattr(asyncio, "Runner"):
         # Python 3.11+
         # Run the coroutines in a new event loop, taking care to
@@ -399,7 +397,7 @@ async def _ahandle_event_for_handler(
 async def ahandle_event(
-    handlers: List[BaseCallbackHandler],
+    handlers: list[BaseCallbackHandler],
     event_name: str,
     ignore_condition_name: Optional[str],
     *args: Any,
@@ -446,13 +444,13 @@ class BaseRunManager(RunManagerMixin):
         self,
         *,
         run_id: UUID,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: List[BaseCallbackHandler],
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: list[BaseCallbackHandler],
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
-        inheritable_tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
     ) -> None:
         """Initialize the run manager.
@@ -481,7 +479,7 @@ class BaseRunManager(RunManagerMixin):
         self.inheritable_metadata = inheritable_metadata or {}
     @classmethod
-    def get_noop_manager(cls: Type[BRM]) -> BRM:
+    def get_noop_manager(cls: type[BRM]) -> BRM:
         """Return a manager that doesn't perform any operations.
         Returns:
@@ -824,7 +822,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
 class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
     """Callback manager for chain run."""
-    def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
         """Run when chain ends running.
         Args:
@@ -929,7 +927,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
     @shielded
     async def on_chain_end(
-        self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
+        self, outputs: Union[dict[str, Any], Any], **kwargs: Any
     ) -> None:
         """Run when a chain ends running.
@@ -1248,11 +1246,11 @@ class CallbackManager(BaseCallbackManager):
     def on_llm_start(
         self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> List[CallbackManagerForLLMRun]:
+    ) -> list[CallbackManagerForLLMRun]:
         """Run when LLM starts running.
         Args:
@@ -1299,11 +1297,11 @@ class CallbackManager(BaseCallbackManager):
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> List[CallbackManagerForLLMRun]:
+    ) -> list[CallbackManagerForLLMRun]:
         """Run when LLM starts running.
         Args:
@@ -1354,8 +1352,8 @@ class CallbackManager(BaseCallbackManager):
     def on_chain_start(
         self,
-        serialized: Optional[Dict[str, Any]],
-        inputs: Union[Dict[str, Any], Any],
+        serialized: Optional[dict[str, Any]],
+        inputs: Union[dict[str, Any], Any],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> CallbackManagerForChainRun:
@@ -1398,11 +1396,11 @@ class CallbackManager(BaseCallbackManager):
     def on_tool_start(
         self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
         input_str: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
-        inputs: Optional[Dict[str, Any]] = None,
+        inputs: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> CallbackManagerForToolRun:
         """Run when tool starts running.
@@ -1453,7 +1451,7 @@ class CallbackManager(BaseCallbackManager):
     def on_retriever_start(
         self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1541,10 +1539,10 @@ class CallbackManager(BaseCallbackManager):
         inheritable_callbacks: Callbacks = None,
         local_callbacks: Callbacks = None,
         verbose: bool = False,
-        inheritable_tags: Optional[List[str]] = None,
-        local_tags: Optional[List[str]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
-        local_metadata: Optional[Dict[str, Any]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        local_tags: Optional[list[str]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
+        local_metadata: Optional[dict[str, Any]] = None,
     ) -> CallbackManager:
         """Configure the callback manager.
@@ -1583,8 +1581,8 @@ class CallbackManagerForChainGroup(CallbackManager):
     def __init__(
         self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
         parent_run_id: Optional[UUID] = None,
         *,
         parent_run_manager: CallbackManagerForChainRun,
@@ -1681,7 +1679,7 @@ class CallbackManagerForChainGroup(CallbackManager):
             manager.add_handler(handler, inherit=True)
         return manager
-    def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
         """Run when traced chain group ends.
         Args:
@@ -1716,11 +1714,11 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_llm_start(
         self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> List[AsyncCallbackManagerForLLMRun]:
+    ) -> list[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running.
         Args:
@@ -1779,11 +1777,11 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> List[AsyncCallbackManagerForLLMRun]:
+    ) -> list[AsyncCallbackManagerForLLMRun]:
         """Async run when LLM starts running.
         Args:
@@ -1840,8 +1838,8 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_chain_start(
         self,
-        serialized: Optional[Dict[str, Any]],
-        inputs: Union[Dict[str, Any], Any],
+        serialized: Optional[dict[str, Any]],
+        inputs: Union[dict[str, Any], Any],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> AsyncCallbackManagerForChainRun:
@@ -1886,7 +1884,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_tool_start(
         self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
         input_str: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1975,7 +1973,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_retriever_start(
         self,
-        serialized: Optional[Dict[str, Any]],
+        serialized: Optional[dict[str, Any]],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -2027,10 +2025,10 @@ class AsyncCallbackManager(BaseCallbackManager):
         inheritable_callbacks: Callbacks = None,
         local_callbacks: Callbacks = None,
         verbose: bool = False,
-        inheritable_tags: Optional[List[str]] = None,
-        local_tags: Optional[List[str]] = None,
-        inheritable_metadata: Optional[Dict[str, Any]] = None,
-        local_metadata: Optional[Dict[str, Any]] = None,
+        inheritable_tags: Optional[list[str]] = None,
+        local_tags: Optional[list[str]] = None,
+        inheritable_metadata: Optional[dict[str, Any]] = None,
+        local_metadata: Optional[dict[str, Any]] = None,
     ) -> AsyncCallbackManager:
         """Configure the async callback manager.
@@ -2069,8 +2067,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
     def __init__(
         self,
-        handlers: List[BaseCallbackHandler],
-        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
+        handlers: list[BaseCallbackHandler],
+        inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
         parent_run_id: Optional[UUID] = None,
         *,
         parent_run_manager: AsyncCallbackManagerForChainRun,
@@ -2169,7 +2167,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
         return manager
     async def on_chain_end(
-        self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
+        self, outputs: Union[dict[str, Any], Any], **kwargs: Any
     ) -> None:
         """Run when traced chain group ends.
@@ -2202,14 +2200,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
 def _configure(
-    callback_manager_cls: Type[T],
+    callback_manager_cls: type[T],
     inheritable_callbacks: Callbacks = None,
     local_callbacks: Callbacks = None,
     verbose: bool = False,
-    inheritable_tags: Optional[List[str]] = None,
-    local_tags: Optional[List[str]] = None,
-    inheritable_metadata: Optional[Dict[str, Any]] = None,
-    local_metadata: Optional[Dict[str, Any]] = None,
+    inheritable_tags: Optional[list[str]] = None,
+    local_tags: Optional[list[str]] = None,
+    inheritable_metadata: Optional[dict[str, Any]] = None,
+    local_metadata: Optional[dict[str, Any]] = None,
 ) -> T:
     """Configure the callback manager.

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.utils import print_text
@@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         self.color = color
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain.
@@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
         print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")  # noqa: T201
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain.
         Args:

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 import sys
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any
 from langchain_core.callbacks.base import BaseCallbackHandler
@@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running.
@@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         **kwargs: Any,
     ) -> None:
         """Run when LLM starts running.
@@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
         """
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when a chain starts running.
@@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
             **kwargs (Any): Additional keyword arguments.
         """
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Run when a chain ends running.
         Args:
@@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
         """
     def on_tool_start(
-        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+        self, serialized: dict[str, Any], input_str: str, **kwargs: Any
     ) -> None:
         """Run when the tool starts running.

View File

@@ -18,7 +18,7 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import List, Sequence, Union
+from typing import Sequence, Union
 from pydantic import BaseModel, Field
@@ -87,7 +87,7 @@ class BaseChatMessageHistory(ABC):
                     f.write("[]")
     """
-    messages: List[BaseMessage]
+    messages: list[BaseMessage]
     """A property or attribute that returns a list of messages.
     In general, getting the messages may involve IO to the underlying
@@ -95,7 +95,7 @@ class BaseChatMessageHistory(ABC):
     latency.
     """
-    async def aget_messages(self) -> List[BaseMessage]:
+    async def aget_messages(self) -> list[BaseMessage]:
         """Async version of getting messages.
         Can over-ride this method to provide an efficient async implementation.
@@ -204,10 +204,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
     Stores messages in a memory list.
     """
-    messages: List[BaseMessage] = Field(default_factory=list)
+    messages: list[BaseMessage] = Field(default_factory=list)
     """A list of messages stored in memory."""
-    async def aget_messages(self) -> List[BaseMessage]:
+    async def aget_messages(self) -> list[BaseMessage]:
         """Async version of getting messages.
         Can over-ride this method to provide an efficient async implementation.

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
 from langchain_core.documents import Document
 from langchain_core.runnables import run_in_executor
@@ -25,17 +25,17 @@ class BaseLoader(ABC):  # noqa: B024
     # Sub-classes should not implement this method directly. Instead, they
     # should implement the lazy load method.
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         """Load data into Document objects."""
         return list(self.lazy_load())
-    async def aload(self) -> List[Document]:
+    async def aload(self) -> list[Document]:
         """Load data into Document objects."""
         return [document async for document in self.alazy_load()]
     def load_and_split(
         self, text_splitter: Optional[TextSplitter] = None
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Load Documents and split into chunks. Chunks are returned as Documents.
         Do not override this method. It should be considered to be deprecated!
@@ -108,7 +108,7 @@ class BaseBlobParser(ABC):
             Generator of documents
         """
-    def parse(self, blob: Blob) -> List[Document]:
+    def parse(self, blob: Blob) -> list[Document]:
         """Eagerly parse the blob into a document or documents.
         This is a convenience method for interactive development environment.

View File

@@ -4,7 +4,7 @@ import contextlib
 import mimetypes
 from io import BufferedReader, BytesIO
 from pathlib import PurePath
-from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
+from typing import Any, Generator, Literal, Optional, Union, cast
 from pydantic import ConfigDict, Field, field_validator, model_validator
@@ -138,7 +138,7 @@ class Blob(BaseMedia):
     @model_validator(mode="before")
     @classmethod
-    def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
+    def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
         """Verify that either data or path is provided."""
         if "data" not in values and "path" not in values:
             raise ValueError("Either data or path must be provided")
@@ -285,7 +285,7 @@ class Document(BaseMedia):
         return True
     @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "document"]

View File

@@ -3,7 +3,7 @@
 from __future__ import annotations
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional
 from pydantic import BaseModel, ConfigDict
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     from langchain_core.embeddings import Embeddings
-def sorted_values(values: Dict[str, str]) -> List[Any]:
+def sorted_values(values: dict[str, str]) -> list[Any]:
     """Return a list of values in dict sorted by key.
     Args:
@@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     """VectorStore that contains information about examples."""
     k: int = 4
     """Number of examples to select."""
-    example_keys: Optional[List[str]] = None
+    example_keys: Optional[list[str]] = None
     """Optional keys to filter examples to."""
-    input_keys: Optional[List[str]] = None
+    input_keys: Optional[list[str]] = None
     """Optional keys to filter input to. If provided, the search is based on
     the input variables instead of all variables."""
-    vectorstore_kwargs: Optional[Dict[str, Any]] = None
+    vectorstore_kwargs: Optional[dict[str, Any]] = None
     """Extra arguments passed to similarity_search function of the vectorstore."""
     model_config = ConfigDict(
@@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     @staticmethod
     def _example_to_text(
-        example: Dict[str, str], input_keys: Optional[List[str]]
+        example: dict[str, str], input_keys: Optional[list[str]]
     ) -> str:
         if input_keys:
             return " ".join(sorted_values({key: example[key] for key in input_keys}))
         else:
             return " ".join(sorted_values(example))
-    def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
+    def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
         # Get the examples from the metadata.
         # This assumes that examples are stored in metadata.
         examples = [dict(e.metadata) for e in documents]
@@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
             examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
         return examples
-    def add_example(self, example: Dict[str, str]) -> str:
+    def add_example(self, example: dict[str, str]) -> str:
         """Add a new example to vectorstore.
         Args:
@@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
         )
         return ids[0]
-    async def aadd_example(self, example: Dict[str, str]) -> str:
+    async def aadd_example(self, example: dict[str, str]) -> str:
         """Async add new example to vectorstore.
         Args:
@@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
 class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     """Select examples based on semantic similarity."""
-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select examples based on semantic similarity.
         Args:
@@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
         )
         return self._documents_to_examples(example_docs)
-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Asynchronously select examples based on semantic similarity.
         Args:
@@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     def from_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         *,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> SemanticSimilarityExampleSelector:
@@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     async def afrom_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         *,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> SemanticSimilarityExampleSelector:
@@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     fetch_k: int = 20
     """Number of examples to fetch to rerank."""
-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Select examples based on Max Marginal Relevance.
         Args:
@@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
         )
         return self._documents_to_examples(example_docs)
-    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
         """Asynchronously select examples based on Max Marginal Relevance.
         Args:
@@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     def from_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         fetch_k: int = 20,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:
@@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     @classmethod
     async def afrom_examples(
         cls,
-        examples: List[dict],
+        examples: list[dict],
         embeddings: Embeddings,
-        vectorstore_cls: Type[VectorStore],
+        vectorstore_cls: type[VectorStore],
         *,
         k: int = 4,
-        input_keys: Optional[List[str]] = None,
+        input_keys: Optional[list[str]] = None,
         fetch_k: int = 20,
-        example_keys: Optional[List[str]] = None,
+        example_keys: Optional[list[str]] = None,
         vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:

View File

@ -8,7 +8,6 @@ from typing import (
Collection, Collection,
Iterable, Iterable,
Iterator, Iterator,
List,
Optional, Optional,
) )
@ -68,7 +67,7 @@ class Node(Serializable):
"""Text contained by the node.""" """Text contained by the node."""
metadata: dict = Field(default_factory=dict) metadata: dict = Field(default_factory=dict)
"""Metadata for the node.""" """Metadata for the node."""
links: List[Link] = Field(default_factory=list) links: list[Link] = Field(default_factory=list)
"""Links associated with the node.""" """Links associated with the node."""
@ -189,7 +188,7 @@ class GraphVectorStore(VectorStore):
*, *,
ids: Optional[Iterable[str]] = None, ids: Optional[Iterable[str]] = None,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
@@ -237,7 +236,7 @@ class GraphVectorStore(VectorStore):
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
@@ -282,7 +281,7 @@ class GraphVectorStore(VectorStore):
self,
documents: Iterable[Document],
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
@@ -332,7 +331,7 @@ class GraphVectorStore(VectorStore):
self,
documents: Iterable[Document],
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
@@ -535,7 +534,7 @@ class GraphVectorStore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
-) -> List[Document]:
+) -> list[Document]:
return list(self.traversal_search(query, k=k, depth=0))
def max_marginal_relevance_search(
@@ -545,7 +544,7 @@ class GraphVectorStore(VectorStore):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
@@ -554,10 +553,10 @@ class GraphVectorStore(VectorStore):
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
-) -> List[Document]:
+) -> list[Document]:
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
-def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
+def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
@@ -580,7 +579,7 @@ class GraphVectorStore(VectorStore):
async def asearch(
self, query: str, search_type: str, **kwargs: Any
-) -> List[Document]:
+) -> list[Document]:
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
@@ -679,7 +678,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-) -> List[Document]:
+) -> list[Document]:
if self.search_type == "traversal":
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
@@ -691,7 +690,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-) -> List[Document]:
+) -> list[Document]:
if self.search_type == "traversal":
return [
doc
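
The hunks above show the commit's single mechanical pattern: with the codebase on Python 3.9+ and Pydantic v2, ruff's UP006 rule rewrites the deprecated `typing.List`/`Dict`/`Set`/`Tuple` aliases to the PEP 585 builtin generics. A minimal sketch of the before/after shape; the `Document` stand-in and `search` helper here are hypothetical illustrations, not the langchain_core classes:

from __future__ import annotations

from typing import Optional  # Optional itself is untouched by UP006


class Document:
    # Hypothetical stand-in for illustration only.
    def __init__(self, page_content: str, metadata: Optional[dict[str, str]] = None) -> None:
        self.page_content = page_content
        self.metadata: dict[str, str] = metadata or {}


# Before: `from typing import List` and `-> List[Document]`.
# After: the builtin generic needs no typing import at all.
def search(query: str, k: int = 4) -> list[Document]:
    return [Document(page_content=query) for _ in range(k)]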

@@ -11,14 +11,11 @@ from typing import (
AsyncIterable,
AsyncIterator,
Callable,
-Dict,
Iterable,
Iterator,
-List,
Literal,
Optional,
Sequence,
-Set,
TypedDict,
TypeVar,
Union,
@@ -71,7 +68,7 @@ class _HashedDocument(Document):
@model_validator(mode="before")
@classmethod
-def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
+def calculate_hashes(cls, values: dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
@@ -125,7 +122,7 @@ class _HashedDocument(Document):
)
-def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
+def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
"""Utility batching function."""
it = iter(iterable)
while True:
@@ -135,9 +132,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
yield chunk
-async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
+async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
"""Utility batching function."""
-batch: List[T] = []
+batch: list[T] = []
async for element in iterable:
if len(batch) < size:
batch.append(element)
@@ -171,7 +168,7 @@ def _deduplicate_in_order(
hashed_documents: Iterable[_HashedDocument],
) -> Iterator[_HashedDocument]:
"""Deduplicate a list of hashed documents while preserving order."""
-seen: Set[str] = set()
+seen: set[str] = set()
for hashed_doc in hashed_documents:
if hashed_doc.hash_ not in seen:
@@ -349,7 +346,7 @@ def index(
uids = []
docs_to_index = []
uids_to_refresh = []
-seen_docs: Set[str] = set()
+seen_docs: set[str] = set()
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
if doc_exists:
if force_update:
@@ -589,7 +586,7 @@ async def aindex(
uids: list[str] = []
docs_to_index: list[Document] = []
uids_to_refresh = []
-seen_docs: Set[str] = set()
+seen_docs: set[str] = set()
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
if doc_exists:
if force_update:
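
The `_batch` helper above appears only in fragments across two hunks. A self-contained sketch of such a batching utility with the new `Iterator[list[T]]` annotation follows; the `islice`-based body is an assumption for illustration, not necessarily the langchain_core implementation:

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
    # Yield lists of at most `size` items until the iterable is exhausted.
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))  # assumed body; the hunks elide it
        if not chunk:
            return
        yield chunk


assert list(_batch(2, range(5))) == [[0, 1], [2, 3], [4]]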

@@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import time
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Sequence, TypedDict
+from typing import Any, Optional, Sequence, TypedDict
from langchain_core._api import beta
from langchain_core.documents import Document
@@ -144,7 +144,7 @@ class RecordManager(ABC):
"""
@abstractmethod
-def exists(self, keys: Sequence[str]) -> List[bool]:
+def exists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the provided keys exist in the database.
Args:
@@ -155,7 +155,7 @@ class RecordManager(ABC):
"""
@abstractmethod
-async def aexists(self, keys: Sequence[str]) -> List[bool]:
+async def aexists(self, keys: Sequence[str]) -> list[bool]:
"""Asynchronously check if the provided keys exist in the database.
Args:
@@ -173,7 +173,7 @@ class RecordManager(ABC):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
-) -> List[str]:
+) -> list[str]:
"""List records in the database based on the provided filters.
Args:
@@ -194,7 +194,7 @@ class RecordManager(ABC):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
-) -> List[str]:
+) -> list[str]:
"""Asynchronously list records in the database based on the provided filters.
Args:
@@ -241,7 +241,7 @@ class InMemoryRecordManager(RecordManager):
super().__init__(namespace)
# Each key points to a dictionary
# of {'group_id': group_id, 'updated_at': timestamp}
-self.records: Dict[str, _Record] = {}
+self.records: dict[str, _Record] = {}
self.namespace = namespace
def create_schema(self) -> None:
@@ -325,7 +325,7 @@ class InMemoryRecordManager(RecordManager):
"""
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
-def exists(self, keys: Sequence[str]) -> List[bool]:
+def exists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the provided keys exist in the database.
Args:
@@ -336,7 +336,7 @@ class InMemoryRecordManager(RecordManager):
"""
return [key in self.records for key in keys]
-async def aexists(self, keys: Sequence[str]) -> List[bool]:
+async def aexists(self, keys: Sequence[str]) -> list[bool]:
"""Async check if the provided keys exist in the database.
Args:
@@ -354,7 +354,7 @@ class InMemoryRecordManager(RecordManager):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
-) -> List[str]:
+) -> list[str]:
"""List records in the database based on the provided filters.
Args:
@@ -390,7 +390,7 @@ class InMemoryRecordManager(RecordManager):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
-) -> List[str]:
+) -> list[str]:
"""Async list records in the database based on the provided filters.
Args:
@@ -449,9 +449,9 @@ class UpsertResponse(TypedDict):
indexed to avoid this issue.
"""
-succeeded: List[str]
+succeeded: list[str]
"""The IDs that were successfully indexed."""
-failed: List[str]
+failed: list[str]
"""The IDs that failed to index."""
@@ -562,7 +562,7 @@ class DocumentIndex(BaseRetriever):
)
@abc.abstractmethod
-def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
+def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
"""Delete by IDs or other criteria.
Calling delete without any input parameters should raise a ValueError!
@@ -579,7 +579,7 @@ class DocumentIndex(BaseRetriever):
"""
async def adelete(
-self, ids: Optional[List[str]] = None, **kwargs: Any
+self, ids: Optional[list[str]] = None, **kwargs: Any
) -> DeleteResponse:
"""Delete by IDs or other criteria. Async variant.
@@ -607,7 +607,7 @@ class DocumentIndex(BaseRetriever):
ids: Sequence[str],
/,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Get documents by id.
Fewer documents may be returned than requested if some IDs are not found or
@@ -633,7 +633,7 @@ class DocumentIndex(BaseRetriever):
ids: Sequence[str],
/,
**kwargs: Any,
-) -> List[Document]:
+) -> list[Document]:
"""Get documents by id.
Fewer documents may be returned than requested if some IDs are not found or

@@ -6,14 +6,11 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
-Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
-Set,
-Type,
TypeVar,
Union,
)
@@ -51,7 +48,7 @@ class LangSmithParams(TypedDict, total=False):
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
-ls_stop: Optional[List[str]]
+ls_stop: Optional[list[str]]
"""Stop words for generation."""
@@ -74,7 +71,7 @@ def get_tokenizer() -> Any:
return GPT2TokenizerFast.from_pretrained("gpt2")
-def _get_token_ids_default_method(text: str) -> List[int]:
+def _get_token_ids_default_method(text: str) -> list[int]:
"""Encode the text into token IDs."""
# get the cached tokenizer
tokenizer = get_tokenizer()
@@ -117,11 +114,11 @@ class BaseLanguageModel(
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
-tags: Optional[List[str]] = Field(default=None, exclude=True)
+tags: Optional[list[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
-metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
+metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
-custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
+custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
default=None, exclude=True
)
"""Optional encoder to use for counting tokens."""
@@ -167,8 +164,8 @@ class BaseLanguageModel(
@abstractmethod
def generate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -202,8 +199,8 @@ class BaseLanguageModel(
@abstractmethod
async def agenerate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -235,8 +232,8 @@ class BaseLanguageModel(
"""
def with_structured_output(
-self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
-) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+self, schema: Union[dict, type[BaseModel]], **kwargs: Any
+) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
@@ -267,7 +264,7 @@ class BaseLanguageModel(
@abstractmethod
def predict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -313,7 +310,7 @@ class BaseLanguageModel(
@abstractmethod
async def apredict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -339,7 +336,7 @@ class BaseLanguageModel(
"""Get the identifying parameters."""
return self.lc_attributes
-def get_token_ids(self, text: str) -> List[int]:
+def get_token_ids(self, text: str) -> list[int]:
"""Return the ordered ids of the tokens in a text.
Args:
@@ -367,7 +364,7 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
-def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
@@ -381,7 +378,7 @@ class BaseLanguageModel(
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
-def _all_required_field_names(cls) -> Set:
+def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.
Use get_pydantic_field_names.
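
One practical consequence of the `custom_get_token_ids: Optional[Callable[[str], list[int]]]` field above: any plain function returning `list[int]` can stand in for the default GPT-2 tokenizer when counting tokens. A hedged sketch; the whitespace tokenizer and the model name are hypothetical:

def whitespace_token_ids(text: str) -> list[int]:
    # Illustrative only: one fake token id per whitespace-separated word.
    return [hash(word) % 50_000 for word in text.split()]


# Assuming a BaseLanguageModel subclass accepts the field at construction:
#     model = SomeChatModel(custom_get_token_ids=whitespace_token_ids)
# get_num_tokens(text) then reduces to len(whitespace_token_ids(text)).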

@@ -15,11 +15,9 @@ from typing import (
Callable,
Dict,
Iterator,
-List,
Literal,
Optional,
Sequence,
-Type,
Union,
cast,
)
@@ -223,7 +221,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@model_validator(mode="before")
@classmethod
-def raise_deprecation(cls, values: Dict) -> Any:
+def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
@@ -277,7 +275,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -300,7 +298,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -356,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
@@ -426,7 +424,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
@@ -499,12 +497,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Custom methods ---
-def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
@@ -513,7 +511,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _get_ls_params(
self,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@@ -550,7 +548,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
return ls_params
-def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
+def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -567,12 +565,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def generate(
self,
-messages: List[List[BaseMessage]],
-stop: Optional[List[str]] = None,
+messages: list[list[BaseMessage]],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
-tags: Optional[List[str]] = None,
-metadata: Optional[Dict[str, Any]] = None,
+tags: Optional[list[str]] = None,
+metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -658,12 +656,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def agenerate(
self,
-messages: List[List[BaseMessage]],
-stop: Optional[List[str]] = None,
+messages: list[list[BaseMessage]],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
-tags: Optional[List[str]] = None,
-metadata: Optional[Dict[str, Any]] = None,
+tags: Optional[list[str]] = None,
+metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -777,8 +775,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def generate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -787,8 +785,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def agenerate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -799,8 +797,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _generate_with_cache(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -839,7 +837,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
-chunks: List[ChatGenerationChunk] = []
+chunks: list[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@@ -876,8 +874,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _agenerate_with_cache(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -916,7 +914,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
-chunks: List[ChatGenerationChunk] = []
+chunks: list[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@@ -954,8 +952,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@abstractmethod
def _generate(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -963,8 +961,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _agenerate(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -980,8 +978,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _stream(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@@ -989,8 +987,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _astream(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@@ -1017,8 +1015,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@@ -1032,8 +1030,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _call_async(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@@ -1048,7 +1046,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
-self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+self, message: str, stop: Optional[list[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
@@ -1069,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1099,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1115,7 +1113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""
-def dict(self, **kwargs: Any) -> Dict:
+def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -1123,18 +1121,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools(
self,
-tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]],  # noqa: UP006
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
def with_structured_output(
self,
-schema: Union[Dict, Type],
+schema: Union[Dict, type],  # noqa: UP006
*,
include_raw: bool = False,
**kwargs: Any,
-) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:  # noqa: UP006
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@@ -1281,8 +1279,8 @@ class SimpleChatModel(BaseChatModel):
def _generate(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -1294,8 +1292,8 @@ class SimpleChatModel(BaseChatModel):
@abstractmethod
def _call(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1303,8 +1301,8 @@ class SimpleChatModel(BaseChatModel):
async def _agenerate(
self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
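
Note the `# noqa: UP006` markers in the `bind_tools` and `with_structured_output` hunks above: `Dict` stays in those public signatures, presumably to avoid breaking downstream code that inspects them, and the rule is suppressed line by line instead of rewritten. A sketch of the suppression pattern; the standalone function here is illustrative, not the BaseChatModel method:

from typing import Dict, Union

from pydantic import BaseModel


def with_structured_output(
    # The typing alias is kept on purpose; the lint rule is silenced per line.
    schema: Union[Dict, type[BaseModel]],  # noqa: UP006
) -> None:
    # Suppressing UP006 on a single line lets a public signature keep the
    # old alias while the rest of the file migrates to builtin generics.
    ...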

@@ -20,8 +20,6 @@ from typing import (
List,
Optional,
Sequence,
-Tuple,
-Type,
Union,
cast,
)
@@ -76,7 +74,7 @@ def _log_error_once(msg: str) -> None:
def create_base_retry_decorator(
-error_types: List[Type[BaseException]],
+error_types: list[type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
@@ -153,10 +151,10 @@ def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
def get_prompts(
-params: Dict[str, Any],
-prompts: List[str],
+params: dict[str, Any],
+prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
-) -> Tuple[Dict[int, List], str, List[int], List[str]]:
+) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached.
Args:
@@ -189,10 +187,10 @@ def get_prompts(
async def aget_prompts(
-params: Dict[str, Any],
-prompts: List[str],
+params: dict[str, Any],
+prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
-) -> Tuple[Dict[int, List], str, List[int], List[str]]:
+) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached. Async version.
Args:
@@ -225,11 +223,11 @@ async def aget_prompts(
def update_cache(
cache: Union[BaseCache, bool, None],
-existing_prompts: Dict[int, List],
+existing_prompts: dict[int, list],
llm_string: str,
-missing_prompt_idxs: List[int],
+missing_prompt_idxs: list[int],
new_results: LLMResult,
-prompts: List[str],
+prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output.
@@ -259,11 +257,11 @@ def update_cache(
async def aupdate_cache(
cache: Union[BaseCache, bool, None],
-existing_prompts: Dict[int, List],
+existing_prompts: dict[int, list],
llm_string: str,
-missing_prompt_idxs: List[int],
+missing_prompt_idxs: list[int],
new_results: LLMResult,
-prompts: List[str],
+prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version.
@@ -306,7 +304,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@model_validator(mode="before")
@classmethod
-def raise_deprecation(cls, values: Dict) -> Any:
+def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
@@ -324,7 +322,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods ---
@property
-def OutputType(self) -> Type[str]:
+def OutputType(self) -> type[str]:
"""Get the input type for this runnable."""
return str
@@ -343,7 +341,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _get_ls_params(
self,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@@ -383,7 +381,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@@ -407,7 +405,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@@ -425,12 +423,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def batch(
self,
-inputs: List[LanguageModelInput],
-config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+inputs: list[LanguageModelInput],
+config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
if not inputs:
return []
@@ -472,12 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def abatch(
self,
-inputs: List[LanguageModelInput],
-config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+inputs: list[LanguageModelInput],
+config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
-) -> List[str]:
+) -> list[str]:
if not inputs:
return []
config = get_config_list(config, len(inputs))
@@ -521,7 +519,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
@@ -583,7 +581,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if (
@@ -649,8 +647,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@abstractmethod
def _generate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -658,8 +656,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _agenerate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -676,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _stream(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@@ -704,7 +702,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _astream(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
@@ -747,9 +745,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def generate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
-callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
+callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@@ -757,9 +755,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def agenerate_prompt(
self,
-prompts: List[PromptValue],
-stop: Optional[List[str]] = None,
-callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
+prompts: list[PromptValue],
+stop: Optional[list[str]] = None,
+callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@@ -769,9 +767,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _generate_helper(
self,
-prompts: List[str],
-stop: Optional[List[str]],
-run_managers: List[CallbackManagerForLLMRun],
+prompts: list[str],
+stop: Optional[list[str]],
+run_managers: list[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@@ -802,14 +800,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def generate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
-callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
+callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
-tags: Optional[Union[List[str], List[List[str]]]] = None,
-metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-run_name: Optional[Union[str, List[str]]] = None,
-run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
+tags: Optional[Union[list[str], list[list[str]]]] = None,
+metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
+run_name: Optional[Union[str, list[str]]] = None,
+run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to a model and return generations.
@@ -987,7 +985,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@staticmethod
def _get_run_ids_list(
-run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
+run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
@@ -1002,9 +1000,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _agenerate_helper(
self,
-prompts: List[str],
-stop: Optional[List[str]],
-run_managers: List[AsyncCallbackManagerForLLMRun],
+prompts: list[str],
+stop: Optional[list[str]],
+run_managers: list[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@@ -1044,14 +1042,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def agenerate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
-callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
+callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
-tags: Optional[Union[List[str], List[List[str]]]] = None,
-metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-run_name: Optional[Union[str, List[str]]] = None,
-run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
+tags: Optional[Union[list[str], list[list[str]]]] = None,
+metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
+run_name: Optional[Union[str, list[str]]] = None,
+run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@@ -1239,11 +1237,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def __call__(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
-tags: Optional[List[str]] = None,
-metadata: Optional[Dict[str, Any]] = None,
+tags: Optional[list[str]] = None,
+metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input.
@@ -1287,11 +1285,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _call_async(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
-tags: Optional[List[str]] = None,
-metadata: Optional[Dict[str, Any]] = None,
+tags: Optional[list[str]] = None,
+metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
@@ -1318,7 +1316,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1344,7 +1342,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
-messages: List[BaseMessage],
+messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -1367,7 +1365,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _llm_type(self) -> str:
"""Return type of llm."""
-def dict(self, **kwargs: Any) -> Dict:
+def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -1443,7 +1441,7 @@ class LLM(BaseLLM):
def _call(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1467,7 +1465,7 @@ class LLM(BaseLLM):
async def _acall(
self,
prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@@ -1500,8 +1498,8 @@ class LLM(BaseLLM):
def _generate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@@ -1520,8 +1518,8 @@ class LLM(BaseLLM):
async def _agenerate(
self,
-prompts: List[str],
-stop: Optional[List[str]] = None,
+prompts: list[str],
+stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:

@@ -11,7 +11,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any
from pydantic import ConfigDict
@@ -55,11 +55,11 @@ class BaseMemory(Serializable, ABC):
@property
@abstractmethod
-def memory_variables(self) -> List[str]:
+def memory_variables(self) -> list[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
-def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return key-value pairs given the text input to the chain.
Args:
@@ -69,7 +69,7 @@ class BaseMemory(Serializable, ABC):
A dictionary of key-value pairs.
"""
-async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Async return key-value pairs given the text input to the chain.
Args:
@@ -81,7 +81,7 @@ class BaseMemory(Serializable, ABC):
return await run_in_executor(None, self.load_memory_variables, inputs)
@abstractmethod
-def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
"""Save the context of this chain run to memory.
Args:
@@ -90,7 +90,7 @@ class BaseMemory(Serializable, ABC):
"""
async def asave_context(
-self, inputs: Dict[str, Any], outputs: Dict[str, str]
+self, inputs: dict[str, Any], outputs: dict[str, str]
) -> None:
"""Async save the context of this chain run to memory.

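The memory interface above now needs nothing from `typing` beyond `Any`. A minimal in-memory sketch of the same contract; the class and its storage scheme are hypothetical, not the langchain implementation:

from typing import Any


class ScratchpadMemory:
    # Hypothetical: keeps raw context strings in a list.
    def __init__(self) -> None:
        self.buffer: list[str] = []

    @property
    def memory_variables(self) -> list[str]:
        return ["history"]

    def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
        return {"history": "\n".join(self.buffer)}

    def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
        self.buffer.append(f"in={inputs} out={outputs}")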
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
from pydantic import ConfigDict, Field, field_validator
@@ -19,7 +19,7 @@ class BaseMessage(Serializable):
Messages are the inputs and outputs of ChatModels.
"""
-content: Union[str, List[Union[str, Dict]]]
+content: Union[str, list[Union[str, dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
@@ -64,7 +64,7 @@ class BaseMessage(Serializable):
return id_value
def __init__(
-self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
+self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
@@ -85,7 +85,7 @@ class BaseMessage(Serializable):
return True
@classmethod
-def get_lc_namespace(cls) -> List[str]:
+def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
@@ -119,9 +119,9 @@ class BaseMessage(Serializable):
def merge_content(
-first_content: Union[str, List[Union[str, Dict]]],
-*contents: Union[str, List[Union[str, Dict]]],
-) -> Union[str, List[Union[str, Dict]]]:
+first_content: Union[str, list[Union[str, dict]]],
+*contents: Union[str, list[Union[str, dict]]],
+) -> Union[str, list[Union[str, dict]]]:
"""Merge two message contents.
Args:
@@ -163,7 +163,7 @@ class BaseMessageChunk(BaseMessage):
"""Message chunk, which can be concatenated with other Message chunks."""
@classmethod
-def get_lc_namespace(cls) -> List[str]:
+def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
@@ -242,7 +242,7 @@ def message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.model_dump()}
-def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
+def messages_to_dict(messages: Sequence[BaseMessage]) -> list[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:

@@ -23,7 +23,6 @@ from typing import (
Optional,
Sequence,
Tuple,
-Type,
Union,
cast,
overload,
)
@@ -166,7 +165,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
raise ValueError(f"Got unexpected message type: {_type}")
-def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
+def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
@@ -208,7 +207,7 @@ def _create_message_from_message_type(
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
-tool_calls: Optional[List[Dict[str, Any]]] = None,
+tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
@@ -230,7 +229,7 @@ def _create_message_from_message_type(
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "system", "function", or "tool".
"""
-kwargs: Dict[str, Any] = {}
+kwargs: dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
@@ -331,7 +330,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
def convert_to_messages(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
@@ -352,18 +351,18 @@ def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
messages: Literal[None] = None, **kwargs: Any
-) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ...
+) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
@overload
def wrapped(
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
-) -> List[BaseMessage]: ...
+) -> list[BaseMessage]: ...
def wrapped(
messages: Optional[Sequence[MessageLikeRepresentation]] = None, **kwargs: Any
) -> Union[
-List[BaseMessage],
-Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
+list[BaseMessage],
+Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
]:
from langchain_core.runnables.base import RunnableLambda
@@ -382,11 +381,11 @@ def filter_messages(
*,
include_names: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
-include_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
-exclude_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
+include_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
+exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
"""Filter messages based on name, type or id.
Args:
@@ -438,7 +437,7 @@ def filter_messages(
]
"""  # noqa: E501
messages = convert_to_messages(messages)
-filtered: List[BaseMessage] = []
+filtered: list[BaseMessage] = []
for msg in messages:
if exclude_names and msg.name in exclude_names:
continue
@@ -469,7 +468,7 @@ def merge_message_runs(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
*,
chunk_separator: str = "\n",
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
"""Merge consecutive Messages of the same type.
**NOTE**: ToolMessages are not merged, as each has a distinct tool call id that
@@ -539,7 +538,7 @@ def merge_message_runs(
if not messages:
return []
messages = convert_to_messages(messages)
-merged: List[BaseMessage] = []
+merged: list[BaseMessage] = []
for msg in messages:
curr = msg.model_copy(deep=True)
last = merged.pop() if merged else None
@@ -569,21 +568,21 @@ def trim_messages(
*,
max_tokens: int,
token_counter: Union[
-Callable[[List[BaseMessage]], int],
+Callable[[list[BaseMessage]], int],
Callable[[BaseMessage], int],
BaseLanguageModel,
],
strategy: Literal["first", "last"] = "last",
allow_partial: bool = False,
end_on: Optional[
-Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
+Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
start_on: Optional[
-Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
+Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
include_system: bool = False,
-text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None,
-) -> List[BaseMessage]:
+text_splitter: Optional[Union[Callable[[str], list[str]], TextSplitter]] = None,
+) -> list[BaseMessage]:
"""Trim messages to be below a token count.
Args:
@@ -875,13 +874,13 @@ def _first_max_tokens(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
-token_counter: Callable[[List[BaseMessage]], int],
-text_splitter: Callable[[str], List[str]],
+token_counter: Callable[[list[BaseMessage]], int],
+text_splitter: Callable[[str], list[str]],
partial_strategy: Optional[Literal["first", "last"]] = None,
end_on: Optional[
-Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
+Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
messages = list(messages)
idx = 0
for i in range(len(messages)):
@@ -949,17 +948,17 @@ def _last_max_tokens(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
-token_counter: Callable[[List[BaseMessage]], int],
-text_splitter: Callable[[str], List[str]],
+token_counter: Callable[[list[BaseMessage]], int],
+text_splitter: Callable[[str], list[str]],
allow_partial: bool = False,
include_system: bool = False,
start_on: Optional[ start_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None, ] = None,
end_on: Optional[ end_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None, ] = None,
) -> List[BaseMessage]: ) -> list[BaseMessage]:
messages = list(messages) messages = list(messages)
if end_on: if end_on:
while messages and not _is_message_type(messages[-1], end_on): while messages and not _is_message_type(messages[-1], end_on):
@ -984,7 +983,7 @@ def _last_max_tokens(
return reversed_[::-1] return reversed_[::-1]
_MSG_CHUNK_MAP: Dict[Type[BaseMessage], Type[BaseMessageChunk]] = { _MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
HumanMessage: HumanMessageChunk, HumanMessage: HumanMessageChunk,
AIMessage: AIMessageChunk, AIMessage: AIMessageChunk,
SystemMessage: SystemMessageChunk, SystemMessage: SystemMessageChunk,
@ -1024,14 +1023,14 @@ def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
) )
def _default_text_splitter(text: str) -> List[str]: def _default_text_splitter(text: str) -> list[str]:
splits = text.split("\n") splits = text.split("\n")
return [s + "\n" for s in splits[:-1]] + splits[-1:] return [s + "\n" for s in splits[:-1]] + splits[-1:]
def _is_message_type( def _is_message_type(
message: BaseMessage, message: BaseMessage,
type_: Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]], type_: Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]],
) -> bool: ) -> bool:
types = [type_] if isinstance(type_, (str, type)) else type_ types = [type_] if isinstance(type_, (str, type)) else type_
types_str = [t for t in types if isinstance(t, str)] types_str = [t for t in types if isinstance(t, str)]
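The migration is caller-visible only in annotations; behavior is unchanged. A minimal sketch of the trimming helper whose new signature appears above (message values invented; assumes a langchain_core release that exports trim_messages):

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.utils import trim_messages

history: list[BaseMessage] = [  # builtin ``list`` now works as the annotation
    SystemMessage("You are helpful."),
    HumanMessage("hi"),
    AIMessage("hello"),
    HumanMessage("how are you?"),
]

# ``len`` as token_counter counts one "token" per message, so this keeps
# only the last two messages.
trimmed = trim_messages(history, max_tokens=2, token_counter=len, strategy="last")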
View File
@@ -4,11 +4,8 @@ from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
-    Dict,
    Generic,
-    List,
    Optional,
-    Type,
    TypeVar,
    Union,
)
@@ -30,7 +27,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
    """Abstract base class for parsing the outputs of a model."""
    @abstractmethod
-    def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
+    def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
        """Parse a list of candidate model Generations into a specific format.
        Args:
@@ -44,7 +41,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
        """
    async def aparse_result(
-        self, result: List[Generation], *, partial: bool = False
+        self, result: list[Generation], *, partial: bool = False
    ) -> T:
        """Async parse a list of candidate model Generations into a specific format.
@@ -71,7 +68,7 @@ class BaseGenerationOutputParser(
        return Union[str, AnyMessage]
    @property
-    def OutputType(self) -> Type[T]:
+    def OutputType(self) -> type[T]:
        """Return the output type for the parser."""
        # even though mypy complains this isn't valid,
        # it is good enough for pydantic to build the schema from
@@ -156,7 +153,7 @@ class BaseOutputParser(
        return Union[str, AnyMessage]
    @property
-    def OutputType(self) -> Type[T]:
+    def OutputType(self) -> type[T]:
        """Return the output type for the parser.
        This property is inferred from the first type argument of the class.
@@ -218,7 +215,7 @@ class BaseOutputParser(
            run_type="parser",
        )
-    def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
+    def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
        """Parse a list of candidate model Generations into a specific format.
        The return value is parsed from only the first Generation in the result, which
@@ -247,7 +244,7 @@ class BaseOutputParser(
        """
    async def aparse_result(
-        self, result: List[Generation], *, partial: bool = False
+        self, result: list[Generation], *, partial: bool = False
    ) -> T:
        """Async parse a list of candidate model Generations into a specific format.
@@ -305,7 +302,7 @@ class BaseOutputParser(
            " This is required for serialization."
        )
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> dict:
        """Return dictionary representation of output parser."""
        output_parser_dict = super().dict(**kwargs)
        try:
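Since parse_result and OutputType now use builtin generics, a subclass can do the same. A hypothetical parser, sketched only to show the annotation style (class name and splitting rule invented):

from langchain_core.output_parsers import BaseOutputParser

class LineListParser(BaseOutputParser[list[str]]):
    """Toy parser: one element per non-empty line of model output."""

    def parse(self, text: str) -> list[str]:
        return [line.strip() for line in text.splitlines() if line.strip()]

assert LineListParser().parse("a\nb\n") == ["a", "b"]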
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from json import JSONDecodeError
-from typing import Any, List, Optional, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
import jsonpatch  # type: ignore[import]
import pydantic
@@ -42,14 +42,14 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    describing the difference between the previous and the current object.
    """
-    pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None  # type: ignore
+    pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None  # type: ignore
    """The Pydantic object to use for validation.
    If None, no validation is performed."""
    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        return jsonpatch.make_patch(prev, next).patch
-    def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]:
+    def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
        if PYDANTIC_MAJOR_VERSION == 2:
            if issubclass(pydantic_object, pydantic.BaseModel):
                return pydantic_object.model_json_schema()
@@ -57,7 +57,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
                return pydantic_object.model_json_schema()
        return pydantic_object.model_json_schema()
-    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Parse the result of an LLM call to a JSON object.
        Args:
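The pydantic_object hook, now typed Optional[type[TBaseModel]], is used like this; the Joke model is invented for illustration:

from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel

class Joke(BaseModel):
    setup: str
    punchline: str

parser = JsonOutputParser(pydantic_object=Joke)
print(parser.get_format_instructions())  # embeds Joke's JSON schema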
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import re
from abc import abstractmethod
from collections import deque
-from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
+from typing import AsyncIterator, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
@@ -22,7 +22,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
    Yields:
        The elements of the iterator, except the last n elements.
    """
-    buffer: Deque[T] = deque()
+    buffer: deque[T] = deque()
    for item in iter:
        buffer.append(item)
        if len(buffer) > n:
@@ -37,7 +37,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
        return "list"
    @abstractmethod
-    def parse(self, text: str) -> List[str]:
+    def parse(self, text: str) -> list[str]:
        """Parse the output of an LLM call.
        Args:
@@ -60,7 +60,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
-    ) -> Iterator[List[str]]:
+    ) -> Iterator[list[str]]:
        buffer = ""
        for chunk in input:
            if isinstance(chunk, BaseMessage):
@@ -92,7 +92,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
-    ) -> AsyncIterator[List[str]]:
+    ) -> AsyncIterator[list[str]]:
        buffer = ""
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
@@ -136,7 +136,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
        return True
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        Returns:
@@ -152,7 +152,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
            "eg: `foo, bar, baz` or `foo,bar,baz`"
        )
-    def parse(self, text: str) -> List[str]:
+    def parse(self, text: str) -> list[str]:
        """Parse the output of an LLM call.
        Args:
@@ -180,7 +180,7 @@ class NumberedListOutputParser(ListOutputParser):
            "For example: \n\n1. foo\n\n2. bar\n\n3. baz"
        )
-    def parse(self, text: str) -> List[str]:
+    def parse(self, text: str) -> list[str]:
        """Parse the output of an LLM call.
        Args:
@@ -217,7 +217,7 @@ class MarkdownListOutputParser(ListOutputParser):
        """Return the format instructions for the Markdown list output."""
        return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
-    def parse(self, text: str) -> List[str]:
+    def parse(self, text: str) -> list[str]:
        """Parse the output of an LLM call.
        Args:
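The concrete parsers differ only in the delimiter they split on; for example (input string invented):

from langchain_core.output_parsers import CommaSeparatedListOutputParser

parser = CommaSeparatedListOutputParser()
assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]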
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import List, Literal, Union
+from typing import Literal, Union
from pydantic import model_validator
from typing_extensions import Self
@@ -69,7 +69,7 @@ class ChatGeneration(Generation):
        return self
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "output"]
@@ -86,12 +86,12 @@ class ChatGenerationChunk(ChatGeneration):
    """Type is used exclusively for serialization purposes."""
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "output"]
    def __add__(
-        self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
+        self, other: Union[ChatGenerationChunk, list[ChatGenerationChunk]]
    ) -> ChatGenerationChunk:
        if isinstance(other, ChatGenerationChunk):
            generation_info = merge_dicts(
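A minimal sketch of the __add__ concatenation path above (chunk contents invented):

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

first = ChatGenerationChunk(message=AIMessageChunk(content="Hel"))
second = ChatGenerationChunk(message=AIMessageChunk(content="lo"))
combined = first + second  # merges message content and generation_info
assert combined.message.content == "Hello"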
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal, Optional
from langchain_core.load import Serializable
from langchain_core.utils._merge import merge_dicts
@@ -25,7 +25,7 @@ class Generation(Serializable):
    text: str
    """Generated text output."""
-    generation_info: Optional[Dict[str, Any]] = None
+    generation_info: Optional[dict[str, Any]] = None
    """Raw response from the provider.
    May include things like the reason for finishing or token log probabilities.
@@ -40,7 +40,7 @@ class Generation(Serializable):
        return True
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "output"]
@@ -49,7 +49,7 @@ class GenerationChunk(Generation):
    """Generation chunk, which can be concatenated with other Generation chunks."""
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "output"]
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from copy import deepcopy
-from typing import List, Literal, Optional, Union
+from typing import Literal, Optional, Union
from pydantic import BaseModel
@@ -18,8 +18,8 @@ class LLMResult(BaseModel):
    wants to return.
    """
-    generations: List[
-        List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
+    generations: list[
+        list[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
    ]
    """Generated outputs.
@@ -45,13 +45,13 @@ class LLMResult(BaseModel):
    accessing relevant information from standardized fields present in
    AIMessage.
    """
-    run: Optional[List[RunInfo]] = None
+    run: Optional[list[RunInfo]] = None
    """List of metadata info for model call for each input."""
    type: Literal["LLMResult"] = "LLMResult"  # type: ignore[assignment]
    """Type is used exclusively for serialization purposes."""
-    def flatten(self) -> List[LLMResult]:
+    def flatten(self) -> list[LLMResult]:
        """Flatten generations into a single list.
        Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
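A short flatten() sketch (generation texts invented):

from langchain_core.outputs import Generation, LLMResult

result = LLMResult(generations=[[Generation(text="a")], [Generation(text="b")]])
flat = result.flatten()  # one single-generation LLMResult per input prompt
assert [r.generations[0][0].text for r in flat] == ["a", "b"]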
View File
@@ -7,7 +7,7 @@ They can be used to represent text, images, or chat message pieces.
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import List, Literal, Sequence, cast
+from typing import Literal, Sequence, cast
from typing_extensions import TypedDict
@@ -33,7 +33,7 @@ class PromptValue(Serializable, ABC):
        return True
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        This is used to determine the namespace of the object when serializing.
        Defaults to ["langchain", "schema", "prompt"].
@@ -45,7 +45,7 @@ class PromptValue(Serializable, ABC):
        """Return prompt value as string."""
    @abstractmethod
-    def to_messages(self) -> List[BaseMessage]:
+    def to_messages(self) -> list[BaseMessage]:
        """Return prompt as a list of Messages."""
@@ -57,7 +57,7 @@ class StringPromptValue(PromptValue):
    type: Literal["StringPromptValue"] = "StringPromptValue"
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        This is used to determine the namespace of the object when serializing.
        Defaults to ["langchain", "prompts", "base"].
@@ -68,7 +68,7 @@ class StringPromptValue(PromptValue):
        """Return prompt as string."""
        return self.text
-    def to_messages(self) -> List[BaseMessage]:
+    def to_messages(self) -> list[BaseMessage]:
        """Return prompt as messages."""
        return [HumanMessage(content=self.text)]
@@ -86,12 +86,12 @@ class ChatPromptValue(PromptValue):
        """Return prompt as string."""
        return get_buffer_string(self.messages)
-    def to_messages(self) -> List[BaseMessage]:
+    def to_messages(self) -> list[BaseMessage]:
        """Return prompt as a list of messages."""
        return list(self.messages)
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        This is used to determine the namespace of the object when serializing.
        Defaults to ["langchain", "prompts", "chat"].
@@ -121,7 +121,7 @@ class ImagePromptValue(PromptValue):
        """Return prompt (image URL) as string."""
        return self.image_url["url"]
-    def to_messages(self) -> List[BaseMessage]:
+    def to_messages(self) -> list[BaseMessage]:
        """Return prompt (image URL) as messages."""
        return [HumanMessage(content=[cast(dict, self.image_url)])]
@@ -136,7 +136,7 @@ class ChatPromptValueConcrete(ChatPromptValue):
    type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        This is used to determine the namespace of the object when serializing.
        Defaults to ["langchain", "prompts", "chat"].
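Both representations round-trip through the interface above, e.g.:

from langchain_core.prompt_values import StringPromptValue

value = StringPromptValue(text="hi")
value.to_string()    # "hi"
value.to_messages()  # [HumanMessage(content="hi")]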
View File
@@ -10,10 +10,8 @@ from typing import (
    Callable,
    Dict,
    Generic,
-    List,
    Mapping,
    Optional,
-    Type,
    TypeVar,
    Union,
)
@@ -45,14 +43,14 @@ class BasePromptTemplate(
):
    """Base class for all prompt templates, returning a prompt."""
-    input_variables: List[str]
+    input_variables: list[str]
    """A list of the names of the variables whose values are required as inputs to the
    prompt."""
-    optional_variables: List[str] = Field(default=[])
+    optional_variables: list[str] = Field(default=[])
    """optional_variables: A list of the names of the variables for placeholder
    or MessagePlaceholder that are optional. These variables are auto inferred
    from the prompt and user need not provide them."""
-    input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
+    input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)  # noqa: UP006
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
@@ -62,9 +60,9 @@ class BasePromptTemplate(
    Partial variables populate the template so that you don't need to
    pass them in every time you call the prompt."""
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[Dict[str, Any]] = None  # noqa: UP006
    """Metadata to be used for tracing."""
-    tags: Optional[List[str]] = None
+    tags: Optional[list[str]] = None
    """Tags to be used for tracing."""
    @model_validator(mode="after")
@@ -89,7 +87,7 @@ class BasePromptTemplate(
        return self
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.
        Returns ["langchain", "schema", "prompt_template"]."""
        return ["langchain", "schema", "prompt_template"]
@@ -115,7 +113,7 @@ class BasePromptTemplate(
    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
        """Get the input schema for the prompt.
        Args:
@@ -136,7 +134,7 @@ class BasePromptTemplate(
            field_definitions={**required_input_variables, **optional_input_variables},
        )
-    def _validate_input(self, inner_input: Any) -> Dict:
+    def _validate_input(self, inner_input: Any) -> dict:
        if not isinstance(inner_input, dict):
            if len(self.input_variables) == 1:
                var_name = self.input_variables[0]
@@ -163,18 +161,18 @@ class BasePromptTemplate(
            raise KeyError(msg)
        return inner_input
-    def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
+    def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
        _inner_input = self._validate_input(inner_input)
        return self.format_prompt(**_inner_input)
    async def _aformat_prompt_with_error_handling(
-        self, inner_input: Dict
+        self, inner_input: dict
    ) -> PromptValue:
        _inner_input = self._validate_input(inner_input)
        return await self.aformat_prompt(**_inner_input)
    def invoke(
-        self, input: Dict, config: Optional[RunnableConfig] = None
+        self, input: dict, config: Optional[RunnableConfig] = None
    ) -> PromptValue:
        """Invoke the prompt.
@@ -199,7 +197,7 @@ class BasePromptTemplate(
        )
    async def ainvoke(
-        self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
+        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Async invoke the prompt.
@@ -261,7 +259,7 @@ class BasePromptTemplate(
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)
-    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
+    def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
        # Get partial params:
        partial_kwargs = {
            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
@@ -307,7 +305,7 @@ class BasePromptTemplate(
        """Return the prompt type key."""
        raise NotImplementedError
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> dict:
        """Return dictionary representation of prompt.
        Args:
@@ -369,7 +367,7 @@ class BasePromptTemplate(
            raise ValueError(f"{save_path} must be json or yaml")
-def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
+def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
    base_info = {"page_content": doc.page_content, **doc.metadata}
    missing_metadata = set(prompt.input_variables).difference(base_info)
    if len(missing_metadata) > 0:
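Note that the input_types and metadata fields above keep typing.Dict and silence the rule with # noqa: UP006 rather than migrating. Caller code is unaffected either way; a minimal invoke sketch (template text invented):

from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a {adjective} joke about {topic}.")
value = prompt.invoke({"adjective": "dry", "topic": "type hints"})  # a PromptValue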
View File
@@ -6,12 +6,10 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
    Any,
-    Dict,
    List,
    Literal,
    Optional,
    Sequence,
-    Set,
    Tuple,
    Type,
    TypedDict,
@@ -60,12 +58,12 @@ class BaseMessagePromptTemplate(Serializable, ABC):
        return True
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
    @abstractmethod
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format messages from kwargs. Should return a list of BaseMessages.
        Args:
@@ -75,7 +73,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
            List of BaseMessages.
        """
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format messages from kwargs.
        Should return a list of BaseMessages.
@@ -89,7 +87,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
    @property
    @abstractmethod
-    def input_variables(self) -> List[str]:
+    def input_variables(self) -> list[str]:
        """Input variables for this prompt template.
        Returns:
@@ -210,7 +208,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
    Defaults to None."""
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
@@ -223,7 +221,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
            variable_name=variable_name, optional=optional, **kwargs
        )
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format messages from kwargs.
        Args:
@@ -251,7 +249,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
        return value
    @property
-    def input_variables(self) -> List[str]:
+    def input_variables(self) -> list[str]:
        """Input variables for this prompt template.
        Returns:
@@ -292,16 +290,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
    """Additional keyword arguments to pass to the prompt template."""
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
    @classmethod
    def from_template(
-        cls: Type[MessagePromptTemplateT],
+        cls: type[MessagePromptTemplateT],
        template: str,
        template_format: str = "f-string",
-        partial_variables: Optional[Dict[str, Any]] = None,
+        partial_variables: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Create a class from a string template.
@@ -329,9 +327,9 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
    @classmethod
    def from_template_file(
-        cls: Type[MessagePromptTemplateT],
+        cls: type[MessagePromptTemplateT],
        template_file: Union[str, Path],
-        input_variables: List[str],
+        input_variables: list[str],
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Create a class from a template file.
@@ -369,7 +367,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
        """
        return self.format(**kwargs)
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format messages from kwargs.
        Args:
@@ -380,7 +378,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
        """
        return [self.format(**kwargs)]
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format messages from kwargs.
        Args:
@@ -392,7 +390,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
        return [await self.aformat(**kwargs)]
    @property
-    def input_variables(self) -> List[str]:
+    def input_variables(self) -> list[str]:
        """
        Input variables for this prompt template.
@@ -423,7 +421,7 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Role of the message."""
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
@@ -462,37 +460,37 @@ _StringImageMessagePromptTemplateT = TypeVar(
class _TextTemplateParam(TypedDict, total=False):
-    text: Union[str, Dict]
+    text: Union[str, dict]
class _ImageTemplateParam(TypedDict, total=False):
-    image_url: Union[str, Dict]
+    image_url: Union[str, dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""
    prompt: Union[
-        StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
+        StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
    ]
    """Prompt template."""
    additional_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the prompt template."""
-    _msg_class: Type[BaseMessage]
+    _msg_class: type[BaseMessage]
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
    @classmethod
    def from_template(
-        cls: Type[_StringImageMessagePromptTemplateT],
-        template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
+        cls: type[_StringImageMessagePromptTemplateT],
+        template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
        template_format: str = "f-string",
        *,
-        partial_variables: Optional[Dict[str, Any]] = None,
+        partial_variables: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> _StringImageMessagePromptTemplateT:
        """Create a class from a string template.
@@ -511,7 +509,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
            ValueError: If the template is not a string or list of strings.
        """
        if isinstance(template, str):
-            prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
+            prompt: Union[StringPromptTemplate, list] = PromptTemplate.from_template(
                template,
                template_format=template_format,
                partial_variables=partial_variables,
@@ -574,9 +572,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
    @classmethod
    def from_template_file(
-        cls: Type[_StringImageMessagePromptTemplateT],
+        cls: type[_StringImageMessagePromptTemplateT],
        template_file: Union[str, Path],
-        input_variables: List[str],
+        input_variables: list[str],
        **kwargs: Any,
    ) -> _StringImageMessagePromptTemplateT:
        """Create a class from a template file.
@@ -593,7 +591,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
            template = f.read()
        return cls.from_template(template, input_variables=input_variables, **kwargs)
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format messages from kwargs.
        Args:
@@ -604,7 +602,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
        """
        return [self.format(**kwargs)]
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format messages from kwargs.
        Args:
@@ -616,7 +614,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
        return [await self.aformat(**kwargs)]
    @property
-    def input_variables(self) -> List[str]:
+    def input_variables(self) -> list[str]:
        """
        Input variables for this prompt template.
@@ -642,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                content=text, additional_kwargs=self.additional_kwargs
            )
        else:
-            content: List = []
+            content: list = []
            for prompt in self.prompt:
                inputs = {var: kwargs[var] for var in prompt.input_variables}
                if isinstance(prompt, StringPromptTemplate):
@@ -670,7 +668,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                content=text, additional_kwargs=self.additional_kwargs
            )
        else:
-            content: List = []
+            content: list = []
            for prompt in self.prompt:
                inputs = {var: kwargs[var] for var in prompt.input_variables}
                if isinstance(prompt, StringPromptTemplate):
@@ -703,16 +701,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""
-    _msg_class: Type[BaseMessage] = HumanMessage
+    _msg_class: type[BaseMessage] = HumanMessage
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
    """AI message prompt template. This is a message sent from the AI."""
-    _msg_class: Type[BaseMessage] = AIMessage
+    _msg_class: type[BaseMessage] = AIMessage
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
@@ -722,10 +720,10 @@ class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
    This is a message that is not sent to the user.
    """
-    _msg_class: Type[BaseMessage] = SystemMessage
+    _msg_class: type[BaseMessage] = SystemMessage
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
@@ -734,7 +732,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
    """Base class for chat prompt templates."""
    @property
-    def lc_attributes(self) -> Dict:
+    def lc_attributes(self) -> dict:
        """
        Return a list of attribute names that should be included in the
        serialized kwargs. These attributes must be accepted by the
@@ -791,10 +789,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
        return ChatPromptValue(messages=messages)
    @abstractmethod
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format kwargs into a list of messages."""
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format kwargs into a list of messages."""
        return self.format_messages(**kwargs)
@@ -935,7 +933,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
    """  # noqa: E501
-    messages: Annotated[List[MessageLike], SkipValidation()]
+    messages: Annotated[list[MessageLike], SkipValidation()]
    """List of messages consisting of either message prompt templates or messages."""
    validate_template: bool = False
    """Whether or not to try validating the template."""
@@ -999,9 +997,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
        ]
        # Automatically infer input variables from messages
-        input_vars: Set[str] = set()
-        optional_variables: Set[str] = set()
-        partial_vars: Dict[str, Any] = {}
+        input_vars: set[str] = set()
+        optional_variables: set[str] = set()
+        partial_vars: dict[str, Any] = {}
        for _message in _messages:
            if isinstance(_message, MessagesPlaceholder) and _message.optional:
                partial_vars[_message.variable_name] = []
@@ -1022,7 +1020,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
        cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]
@@ -1071,7 +1069,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
        messages = values["messages"]
        input_vars = set()
        optional_variables = set()
-        input_types: Dict[str, Any] = values.get("input_types", {})
+        input_types: dict[str, Any] = values.get("input_types", {})
        for message in messages:
            if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
                input_vars.update(message.input_variables)
@@ -1125,7 +1123,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
    @classmethod
    @deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
    def from_role_strings(
-        cls, string_messages: List[Tuple[str, str]]
+        cls, string_messages: list[tuple[str, str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role, template) tuples.
@@ -1145,7 +1143,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
    @classmethod
    @deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
    def from_strings(
-        cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
+        cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role class, template) tuples.
@@ -1200,7 +1198,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
        """
        return cls(messages, template_format=template_format)
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format the chat template into a list of finalized messages.
        Args:
@@ -1224,7 +1222,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
                raise ValueError(f"Unexpected input: {message_template}")
        return result
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format the chat template into a list of finalized messages.
        Args:
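A short sketch of the placeholder and format_messages paths touched above (roles and text invented):

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

template = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a {role}."),
        MessagesPlaceholder("history", optional=True),
        ("human", "{question}"),
    ]
)
messages = template.format_messages(role="tutor", question="What is PEP 585?")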
View File
@@ -3,7 +3,7 @@
from __future__ import annotations
from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
from pydantic import (
    BaseModel,
@@ -31,7 +31,7 @@ from langchain_core.prompts.string import (
class _FewShotPromptTemplateMixin(BaseModel):
    """Prompt template that contains few shot examples."""
-    examples: Optional[List[dict]] = None
+    examples: Optional[list[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""
@@ -46,7 +46,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
    @model_validator(mode="before")
    @classmethod
-    def check_examples_and_selector(cls, values: Dict) -> Any:
+    def check_examples_and_selector(cls, values: dict) -> Any:
        """Check that one and only one of examples/example_selector are provided.
        Args:
@@ -73,7 +73,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
        return values
-    def _get_examples(self, **kwargs: Any) -> List[dict]:
+    def _get_examples(self, **kwargs: Any) -> list[dict]:
        """Get the examples to use for formatting the prompt.
        Args:
@@ -94,7 +94,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
            "One of 'examples' and 'example_selector' should be provided"
        )
-    async def _aget_examples(self, **kwargs: Any) -> List[dict]:
+    async def _aget_examples(self, **kwargs: Any) -> list[dict]:
        """Async get the examples to use for formatting the prompt.
        Args:
@@ -363,7 +363,7 @@ class FewShotChatMessagePromptTemplate(
        chain.invoke({"input": "What's 3+3?"})
    """
-    input_variables: List[str] = Field(default_factory=list)
+    input_variables: list[str] = Field(default_factory=list)
    """A list of the names of the variables the prompt template will use
    to pass to the example_selector, if provided."""
@@ -380,7 +380,7 @@ class FewShotChatMessagePromptTemplate(
        extra="forbid",
    )
-    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format kwargs into a list of messages.
        Args:
@@ -402,7 +402,7 @@ class FewShotChatMessagePromptTemplate(
        ]
        return messages
-    async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format kwargs into a list of messages.
        Args:
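A minimal few-shot sketch exercising the format_messages path above (example data invented):

from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)

example_prompt = ChatPromptTemplate.from_messages(
    [("human", "{input}"), ("ai", "{output}")]
)
few_shot = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=[{"input": "2+2", "output": "4"}],
)
messages = few_shot.format_messages()  # alternating human/ai example messages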
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, model_validator
@@ -54,13 +54,13 @@ class PromptTemplate(StringPromptTemplate):
    """
    @property
-    def lc_attributes(self) -> Dict[str, Any]:
+    def lc_attributes(self) -> dict[str, Any]:
        return {
            "template_format": self.template_format,
        }
    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "prompt"]
@@ -76,7 +76,7 @@ class PromptTemplate(StringPromptTemplate):
    @model_validator(mode="before")
    @classmethod
-    def pre_init_validation(cls, values: Dict) -> Any:
+    def pre_init_validation(cls, values: dict) -> Any:
        """Check that template and input variables are consistent."""
        if values.get("template") is None:
            # Will let pydantic fail with a ValidationError if template
@@ -183,9 +183,9 @@ class PromptTemplate(StringPromptTemplate):
    @classmethod
    def from_examples(
        cls,
-        examples: List[str],
+        examples: list[str],
        suffix: str,
-        input_variables: List[str],
+        input_variables: list[str],
        example_separator: str = "\n\n",
        prefix: str = "",
        **kwargs: Any,
@@ -215,7 +215,7 @@ class PromptTemplate(StringPromptTemplate):
    def from_file(
        cls,
        template_file: Union[str, Path],
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
        encoding: Optional[str] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
@@ -249,7 +249,7 @@ class PromptTemplate(StringPromptTemplate):
        template: str,
        *,
        template_format: str = "f-string",
-        partial_variables: Optional[Dict[str, Any]] = None,
+        partial_variables: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> PromptTemplate:
        """Load a prompt template from a template.
View File
@ -5,7 +5,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC from abc import ABC
from string import Formatter from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type from typing import Any, Callable, Dict
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -60,7 +60,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
return SandboxedEnvironment().from_string(template).render(**kwargs) return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None: def validate_jinja2(template: str, input_variables: list[str]) -> None:
""" """
Validate that the input variables are valid for the template. Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found. Issues a warning if missing or extra variables are found.
@ -85,7 +85,7 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
warnings.warn(warning_message.strip(), stacklevel=7) warnings.warn(warning_message.strip(), stacklevel=7)
def _get_jinja2_variables_from_template(template: str) -> Set[str]: def _get_jinja2_variables_from_template(template: str) -> set[str]:
try: try:
from jinja2 import Environment, meta from jinja2 import Environment, meta
except ImportError as e: except ImportError as e:
@ -114,7 +114,7 @@ def mustache_formatter(template: str, **kwargs: Any) -> str:
def mustache_template_vars( def mustache_template_vars(
template: str, template: str,
) -> Set[str]: ) -> set[str]:
"""Get the variables from a mustache template. """Get the variables from a mustache template.
Args: Args:
@ -123,7 +123,7 @@ def mustache_template_vars(
Returns: Returns:
The variables from the template. The variables from the template.
""" """
vars: Set[str] = set() vars: set[str] = set()
section_depth = 0 section_depth = 0
for type, key in mustache.tokenize(template): for type, key in mustache.tokenize(template):
if type == "end": if type == "end":
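A usage sketch of the tokenizer-driven extraction above; the expected output assumes, per the `section_depth` logic, that top-level section keys are collected while variables inside a section are not:

from langchain_core.prompts.string import mustache_template_vars

vars_found = mustache_template_vars("Hi {{name}}! {{#org}}{{team}}{{/org}}")
# Expected: {"name", "org"} -- "team" sits inside a section,
# so only the section key "org" is recorded.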
@ -144,7 +144,7 @@ Defs = Dict[str, "Defs"]
def mustache_schema( def mustache_schema(
template: str, template: str,
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get the variables from a mustache template. """Get the variables from a mustache template.
Args: Args:
@ -154,8 +154,8 @@ def mustache_schema(
The variables from the template as a Pydantic model. The variables from the template as a Pydantic model.
""" """
fields = {} fields = {}
prefix: Tuple[str, ...] = () prefix: tuple[str, ...] = ()
section_stack: List[Tuple[str, ...]] = [] section_stack: list[tuple[str, ...]] = []
for type, key in mustache.tokenize(template): for type, key in mustache.tokenize(template):
if key == ".": if key == ".":
continue continue
@ -178,7 +178,7 @@ def mustache_schema(
return _create_model_recursive("PromptInput", defs) return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type: def _create_model_recursive(name: str, defs: Defs) -> type:
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
name, name,
**{ **{
@ -188,20 +188,20 @@ def _create_model_recursive(name: str, defs: Defs) -> Type:
) )
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
"f-string": formatter.format, "f-string": formatter.format,
"mustache": mustache_formatter, "mustache": mustache_formatter,
"jinja2": jinja2_formatter, "jinja2": jinja2_formatter,
} }
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
"f-string": formatter.validate_input_variables, "f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2, "jinja2": validate_jinja2,
} }
def check_valid_template( def check_valid_template(
template: str, template_format: str, input_variables: List[str] template: str, template_format: str, input_variables: list[str]
) -> None: ) -> None:
"""Check that template string is valid. """Check that template string is valid.
@ -230,7 +230,7 @@ def check_valid_template(
) from exc ) from exc
def get_template_variables(template: str, template_format: str) -> List[str]: def get_template_variables(template: str, template_format: str) -> list[str]:
"""Get the variables from the template. """Get the variables from the template.
Args: Args:
@ -262,7 +262,7 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt.""" """String prompt that exposes the format method, returning a prompt."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "base"] return ["langchain", "prompts", "base"]

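The `Defs` alias and `_create_model_recursive` above turn nested mustache sections into nested Pydantic models. A standalone sketch of the same idea (names are illustrative, not the library's):

from pydantic import BaseModel, create_model

Defs = dict[str, "Defs"]  # recursive alias, mirroring the diff

def create_model_recursive(name: str, defs: Defs) -> type[BaseModel]:
    # Empty sub-dict -> plain string field; non-empty -> nested model.
    return create_model(
        name,
        **{
            key: (create_model_recursive(key, sub), None) if sub else (str, None)
            for key, sub in defs.items()
        },
    )

schema = create_model_recursive("PromptInput", {"name": {}, "org": {"team": {}}})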

@ -24,7 +24,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, List, Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -132,14 +132,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
_new_arg_supported: bool = False _new_arg_supported: bool = False
_expects_other_args: bool = False _expects_other_args: bool = False
tags: Optional[List[str]] = None tags: Optional[list[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None. """Optional list of tags associated with the retriever. Defaults to None.
These tags will be associated with each call to this retriever, These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
You can use these to, e.g., identify a specific instance of a retriever with its You can use these to, e.g., identify a specific instance of a retriever with its
use case. use case.
""" """
metadata: Optional[Dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the retriever. Defaults to None. """Optional metadata associated with the retriever. Defaults to None.
This metadata will be associated with each call to this retriever, This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
@ -200,7 +200,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]: ) -> list[Document]:
"""Invoke the retriever to get relevant documents. """Invoke the retriever to get relevant documents.
Main entry point for synchronous retriever invocations. Main entry point for synchronous retriever invocations.
@ -263,7 +263,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
input: str, input: str,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> list[Document]:
"""Asynchronously invoke the retriever to get relevant documents. """Asynchronously invoke the retriever to get relevant documents.
Main entry point for asynchronous retriever invocations. Main entry point for asynchronous retriever invocations.
@ -324,7 +324,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
@abstractmethod @abstractmethod
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> list[Document]:
"""Get documents relevant to a query. """Get documents relevant to a query.
Args: Args:
@ -336,7 +336,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]: ) -> list[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
Args: Args:
@ -358,11 +358,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
query: str, query: str,
*, *,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> list[Document]:
"""Retrieve documents relevant to a query. """Retrieve documents relevant to a query.
Users should favor using `.invoke` or `.batch` rather than Users should favor using `.invoke` or `.batch` rather than
@ -402,11 +402,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
query: str, query: str,
*, *,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> list[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
Users should favor using `.ainvoke` or `.abatch` rather than Users should favor using `.ainvoke` or `.abatch` rather than

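With the annotations above moved to builtin generics, a subclass looks the same as before — only the spelling of the types changes. A minimal, hypothetical retriever:

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class KeywordRetriever(BaseRetriever):
    docs: list[Document]  # PEP 585 annotation, validated by Pydantic v2

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

retriever = KeywordRetriever(docs=[Document(page_content="PEP 585 ships in 3.9")])
retriever.invoke("pep 585")  # -> [Document(page_content="PEP 585 ships in 3.9")]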

@ -27,8 +27,6 @@ from typing import (
Optional, Optional,
Protocol, Protocol,
Sequence, Sequence,
Set,
Tuple,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -273,7 +271,7 @@ class Runnable(Generic[Input, Output], ABC):
return name_ return name_
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> type[Input]:
"""The type of input this Runnable accepts specified as a type annotation.""" """The type of input this Runnable accepts specified as a type annotation."""
# First loop through all parent classes and if any of them is # First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization # a pydantic model, we will pick up the generic parameterization
@ -298,7 +296,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
"""The type of output this Runnable produces specified as a type annotation.""" """The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help resolve generics on # First loop through bases -- this will help resolve generics on
# any pydantic models. # any pydantic models.
@ -319,13 +317,13 @@ class Runnable(Generic[Input, Output], ABC):
) )
@property @property
def input_schema(self) -> Type[BaseModel]: def input_schema(self) -> type[BaseModel]:
"""The type of input this Runnable accepts specified as a pydantic model.""" """The type of input this Runnable accepts specified as a pydantic model."""
return self.get_input_schema() return self.get_input_schema()
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate input to the Runnable. """Get a pydantic model that can be used to validate input to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives Runnables that leverage the configurable_fields and configurable_alternatives
@ -360,7 +358,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_input_jsonschema( def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Get a JSON schema that represents the input to the Runnable. """Get a JSON schema that represents the input to the Runnable.
Args: Args:
@ -387,13 +385,13 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_input_schema(config).model_json_schema() return self.get_input_schema(config).model_json_schema()
@property @property
def output_schema(self) -> Type[BaseModel]: def output_schema(self) -> type[BaseModel]:
"""The type of output this Runnable produces specified as a pydantic model.""" """The type of output this Runnable produces specified as a pydantic model."""
return self.get_output_schema() return self.get_output_schema()
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable. """Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives Runnables that leverage the configurable_fields and configurable_alternatives
@ -428,7 +426,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_output_jsonschema( def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable. """Get a JSON schema that represents the output of the Runnable.
Args: Args:
@ -455,13 +453,13 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_output_schema(config).model_json_schema() return self.get_output_schema(config).model_json_schema()
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
"""List configurable fields for this Runnable.""" """List configurable fields for this Runnable."""
return [] return []
def config_schema( def config_schema(
self, *, include: Optional[Sequence[str]] = None self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""The type of config this Runnable accepts specified as a pydantic model. """The type of config this Runnable accepts specified as a pydantic model.
To mark a field as configurable, see the `configurable_fields` To mark a field as configurable, see the `configurable_fields`
@ -509,7 +507,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_config_jsonschema( def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None self, *, include: Optional[Sequence[str]] = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable. """Get a JSON schema that represents the output of the Runnable.
Args: Args:
@ -544,7 +542,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_prompts( def get_prompts(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> List[BasePromptTemplate]: ) -> list[BasePromptTemplate]:
"""Return a list of prompts used by this Runnable.""" """Return a list of prompts used by this Runnable."""
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
@ -614,7 +612,7 @@ class Runnable(Generic[Input, Output], ABC):
""" """
return RunnableSequence(self, *others, name=name) return RunnableSequence(self, *others, name=name)
def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]: def pick(self, keys: Union[str, list[str]]) -> RunnableSerializable[Any, Any]:
"""Pick keys from the dict output of this Runnable. """Pick keys from the dict output of this Runnable.
Pick single key: Pick single key:
@ -670,11 +668,11 @@ class Runnable(Generic[Input, Output], ABC):
def assign( def assign(
self, self,
**kwargs: Union[ **kwargs: Union[
Runnable[Dict[str, Any], Any], Runnable[dict[str, Any], Any],
Callable[[Dict[str, Any]], Any], Callable[[dict[str, Any]], Any],
Mapping[ Mapping[
str, str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
], ],
], ],
) -> RunnableSerializable[Any, Any]: ) -> RunnableSerializable[Any, Any]:
@ -710,7 +708,7 @@ class Runnable(Generic[Input, Output], ABC):
""" """
from langchain_core.runnables.passthrough import RunnableAssign from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs)) return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
""" --- Public API --- """ """ --- Public API --- """
@ -744,12 +742,12 @@ class Runnable(Generic[Input, Output], ABC):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
"""Default implementation runs invoke in parallel using a thread pool executor. """Default implementation runs invoke in parallel using a thread pool executor.
The default implementation of batch works well for IO bound runnables. The default implementation of batch works well for IO bound runnables.
@ -786,7 +784,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: Literal[False] = False, return_exceptions: Literal[False] = False,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[int, Output]]: ... ) -> Iterator[tuple[int, Output]]: ...
@overload @overload
def batch_as_completed( def batch_as_completed(
@ -796,7 +794,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: Literal[True], return_exceptions: Literal[True],
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ... ) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
def batch_as_completed( def batch_as_completed(
self, self,
@ -805,7 +803,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ) -> Iterator[tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs, """Run invoke in parallel on a list of inputs,
yielding results as they complete.""" yielding results as they complete."""
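A usage sketch: `batch_as_completed` yields `(index, result)` tuples in completion order, so callers can match each output back to its input position.

from langchain_core.runnables import RunnableLambda

square = RunnableLambda(lambda x: x * x)
for i, out in square.batch_as_completed([1, 2, 3]):
    print(i, out)  # (index, result) pairs, yielded as each input finishes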
@ -816,7 +814,7 @@ class Runnable(Generic[Input, Output], ABC):
def invoke( def invoke(
i: int, input: Input, config: RunnableConfig i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]: ) -> tuple[int, Union[Output, Exception]]:
if return_exceptions: if return_exceptions:
try: try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs) out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
@ -848,12 +846,12 @@ class Runnable(Generic[Input, Output], ABC):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
"""Default implementation runs ainvoke in parallel using asyncio.gather. """Default implementation runs ainvoke in parallel using asyncio.gather.
The default implementation of batch works well for IO bound runnables. The default implementation of batch works well for IO bound runnables.
@ -902,7 +900,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: Literal[False] = False, return_exceptions: Literal[False] = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]: ... ) -> AsyncIterator[tuple[int, Output]]: ...
@overload @overload
def abatch_as_completed( def abatch_as_completed(
@ -912,7 +910,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: Literal[True], return_exceptions: Literal[True],
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ... ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
async def abatch_as_completed( async def abatch_as_completed(
self, self,
@ -921,7 +919,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
"""Run ainvoke in parallel on a list of inputs, """Run ainvoke in parallel on a list of inputs,
yielding results as they complete. yielding results as they complete.
@ -947,7 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
async def ainvoke( async def ainvoke(
i: int, input: Input, config: RunnableConfig i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]: ) -> tuple[int, Union[Output, Exception]]:
if return_exceptions: if return_exceptions:
try: try:
out: Union[Output, Exception] = await self.ainvoke( out: Union[Output, Exception] = await self.ainvoke(
@ -1699,8 +1697,8 @@ class Runnable(Generic[Input, Output], ABC):
def with_types( def with_types(
self, self,
*, *,
input_type: Optional[Type[Input]] = None, input_type: Optional[type[Input]] = None,
output_type: Optional[Type[Output]] = None, output_type: Optional[type[Output]] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """
Bind input and output types to a Runnable, returning a new Runnable. Bind input and output types to a Runnable, returning a new Runnable.
@ -1722,7 +1720,7 @@ class Runnable(Generic[Input, Output], ABC):
def with_retry( def with_retry(
self, self,
*, *,
retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,),
wait_exponential_jitter: bool = True, wait_exponential_jitter: bool = True,
stop_after_attempt: int = 3, stop_after_attempt: int = 3,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
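A `with_retry` sketch against a deliberately flaky function (hypothetical; the counter only demonstrates the retry loop):

from langchain_core.runnables import RunnableLambda

attempts = {"n": 0}

def flaky(x: int) -> int:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ValueError("transient failure")
    return x

RunnableLambda(flaky).with_retry(
    retry_if_exception_type=(ValueError,),  # builtin tuple generic, per the new signature
    stop_after_attempt=3,
).invoke(7)  # -> 7, succeeding on the third attempt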
@ -1789,7 +1787,7 @@ class Runnable(Generic[Input, Output], ABC):
max_attempt_number=stop_after_attempt, max_attempt_number=stop_after_attempt,
) )
def map(self) -> Runnable[List[Input], List[Output]]: def map(self) -> Runnable[list[Input], list[Output]]:
""" """
Return a new Runnable that maps a list of inputs to a list of outputs, Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input. by calling invoke() with each input.
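For example:

from langchain_core.runnables import RunnableLambda

upper = RunnableLambda(lambda s: s.upper())
upper.map().invoke(["a", "b"])  # -> ["A", "B"]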
@ -1815,7 +1813,7 @@ class Runnable(Generic[Input, Output], ABC):
self, self,
fallbacks: Sequence[Runnable[Input, Output]], fallbacks: Sequence[Runnable[Input, Output]],
*, *,
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,),
exception_key: Optional[str] = None, exception_key: Optional[str] = None,
) -> RunnableWithFallbacksT[Input, Output]: ) -> RunnableWithFallbacksT[Input, Output]:
"""Add fallbacks to a Runnable, returning a new Runnable. """Add fallbacks to a Runnable, returning a new Runnable.
@ -1893,7 +1891,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
serialized: Optional[Dict[str, Any]] = None, serialized: Optional[dict[str, Any]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
@ -1942,7 +1940,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
serialized: Optional[Dict[str, Any]] = None, serialized: Optional[dict[str, Any]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
@ -1977,23 +1975,23 @@ class Runnable(Generic[Input, Output], ABC):
def _batch_with_config( def _batch_with_config(
self, self,
func: Union[ func: Union[
Callable[[List[Input]], List[Union[Exception, Output]]], Callable[[list[Input]], list[Union[Exception, Output]]],
Callable[ Callable[
[List[Input], List[CallbackManagerForChainRun]], [list[Input], list[CallbackManagerForChainRun]],
List[Union[Exception, Output]], list[Union[Exception, Output]],
], ],
Callable[ Callable[
[List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], [list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]],
List[Union[Exception, Output]], list[Union[Exception, Output]],
], ],
], ],
input: List[Input], input: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
if not input: if not input:
@ -2045,27 +2043,27 @@ class Runnable(Generic[Input, Output], ABC):
async def _abatch_with_config( async def _abatch_with_config(
self, self,
func: Union[ func: Union[
Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], Callable[[list[Input]], Awaitable[list[Union[Exception, Output]]]],
Callable[ Callable[
[List[Input], List[AsyncCallbackManagerForChainRun]], [list[Input], list[AsyncCallbackManagerForChainRun]],
Awaitable[List[Union[Exception, Output]]], Awaitable[list[Union[Exception, Output]]],
], ],
Callable[ Callable[
[ [
List[Input], list[Input],
List[AsyncCallbackManagerForChainRun], list[AsyncCallbackManagerForChainRun],
List[RunnableConfig], list[RunnableConfig],
], ],
Awaitable[List[Union[Exception, Output]]], Awaitable[list[Union[Exception, Output]]],
], ],
], ],
input: List[Input], input: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
if not input: if not input:
@ -2073,7 +2071,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(input)) configs = get_config_list(config, len(input))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs] callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
callback_manager.on_chain_start( callback_manager.on_chain_start(
None, None,
@ -2106,7 +2104,7 @@ class Runnable(Generic[Input, Output], ABC):
raise raise
else: else:
first_exception: Optional[Exception] = None first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = [] coros: list[Awaitable[None]] = []
for run_manager, out in zip(run_managers, output): for run_manager, out in zip(run_managers, output):
if isinstance(out, Exception): if isinstance(out, Exception):
first_exception = first_exception or out first_exception = first_exception or out
@ -2333,11 +2331,11 @@ class Runnable(Generic[Input, Output], ABC):
@beta_decorator.beta(message="This API is in beta and may change in the future.") @beta_decorator.beta(message="This API is in beta and may change in the future.")
def as_tool( def as_tool(
self, self,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[type[BaseModel]] = None,
*, *,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None, arg_types: Optional[dict[str, type]] = None,
) -> BaseTool: ) -> BaseTool:
"""Create a BaseTool from a Runnable. """Create a BaseTool from a Runnable.
@ -2573,8 +2571,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def _seq_input_schema( def _seq_input_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig] steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
first = steps[0] first = steps[0]
@ -2599,8 +2597,8 @@ def _seq_input_schema(
def _seq_output_schema( def _seq_output_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig] steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
last = steps[-1] last = steps[-1]
@ -2730,7 +2728,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# the last type. # the last type.
first: Runnable[Input, Any] first: Runnable[Input, Any]
"""The first Runnable in the sequence.""" """The first Runnable in the sequence."""
middle: List[Runnable[Any, Any]] = Field(default_factory=list) middle: list[Runnable[Any, Any]] = Field(default_factory=list)
"""The middle Runnables in the sequence.""" """The middle Runnables in the sequence."""
last: Runnable[Any, Output] last: Runnable[Any, Output]
"""The last Runnable in the sequence.""" """The last Runnable in the sequence."""
@ -2740,7 +2738,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
*steps: RunnableLike, *steps: RunnableLike,
name: Optional[str] = None, name: Optional[str] = None,
first: Optional[Runnable[Any, Any]] = None, first: Optional[Runnable[Any, Any]] = None,
middle: Optional[List[Runnable[Any, Any]]] = None, middle: Optional[list[Runnable[Any, Any]]] = None,
last: Optional[Runnable[Any, Any]] = None, last: Optional[Runnable[Any, Any]] = None,
) -> None: ) -> None:
"""Create a new RunnableSequence. """Create a new RunnableSequence.
@ -2755,7 +2753,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Raises: Raises:
ValueError: If the sequence has fewer than 2 steps. ValueError: If the sequence has fewer than 2 steps.
""" """
steps_flat: List[Runnable] = [] steps_flat: list[Runnable] = []
if not steps: if not steps:
if first is not None and last is not None: if first is not None and last is not None:
steps_flat = [first] + (middle or []) + [last] steps_flat = [first] + (middle or []) + [last]
@ -2776,12 +2774,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
def steps(self) -> List[Runnable[Any, Any]]: def steps(self) -> list[Runnable[Any, Any]]:
"""All the Runnables that make up the sequence in order. """All the Runnables that make up the sequence in order.
Returns: Returns:
@ -2804,18 +2802,18 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) )
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> type[Input]:
"""The type of the input to the Runnable.""" """The type of the input to the Runnable."""
return self.first.InputType return self.first.InputType
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
"""The type of the output of the Runnable.""" """The type of the output of the Runnable."""
return self.last.OutputType return self.last.OutputType
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get the input schema of the Runnable. """Get the input schema of the Runnable.
Args: Args:
@ -2828,7 +2826,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get the output schema of the Runnable. """Get the output schema of the Runnable.
Args: Args:
@ -2840,7 +2838,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return _seq_output_schema(self.steps, config) return _seq_output_schema(self.steps, config)
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the config specs of the Runnable. """Get the config specs of the Runnable.
Returns: Returns:
@ -2862,8 +2860,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)], [tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
lambda x: x[1], lambda x: x[1],
) )
next_deps: Set[str] = set() next_deps: set[str] = set()
deps_by_pos: Dict[int, Set[str]] = {} deps_by_pos: dict[int, set[str]] = {}
for pos, specs in specs_by_pos: for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs} next_deps = next_deps | {spec[0].id for spec in specs}
@ -3065,12 +3063,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import config_with_context from langchain_core.beta.runnables.context import config_with_context
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
@ -3111,7 +3109,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# Track which inputs (by index) failed so far # Track which inputs (by index) failed so far
# If an input has failed it will be present in this map, # If an input has failed it will be present in this map,
# and the value will be the exception that was raised. # and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {} failed_inputs_map: dict[int, Exception] = {}
for stepidx, step in enumerate(self.steps): for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs # Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet) # (i.e. the ones that haven't failed yet)
@ -3191,12 +3189,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context from langchain_core.beta.runnables.context import aconfig_with_context
from langchain_core.callbacks.manager import AsyncCallbackManager from langchain_core.callbacks.manager import AsyncCallbackManager
@ -3221,7 +3219,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for config in configs for config in configs
] ]
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start( cm.on_chain_start(
None, None,
@ -3240,7 +3238,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# Track which inputs (by index) failed so far # Track which inputs (by index) failed so far
# If an input has failed it will be present in this map, # If an input has failed it will be present in this map,
# and the value will be the exception that was raised. # and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {} failed_inputs_map: dict[int, Exception] = {}
for stepidx, step in enumerate(self.steps): for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs # Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet) # (i.e. the ones that haven't failed yet)
@ -3305,7 +3303,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
raise raise
else: else:
first_exception: Optional[Exception] = None first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = [] coros: list[Awaitable[None]] = []
for run_manager, out in zip(run_managers, inputs): for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception): if isinstance(out, Exception):
first_exception = first_exception or out first_exception = first_exception or out
@ -3533,7 +3531,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -3567,7 +3565,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get the input schema of the Runnable. """Get the input schema of the Runnable.
Args: Args:
@ -3596,7 +3594,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get the output schema of the Runnable. """Get the output schema of the Runnable.
Args: Args:
@ -3609,7 +3607,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return create_model_v2(self.get_name("Output"), field_definitions=fields) return create_model_v2(self.get_name("Output"), field_definitions=fields)
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the config specs of the Runnable. """Get the config specs of the Runnable.
Returns: Returns:
@ -3662,7 +3660,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
@ -3724,7 +3722,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input, input: Input,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
# setup callbacks # setup callbacks
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
@ -3829,7 +3827,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Iterator[Input], input: Iterator[Input],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config( yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs input, self._transform, config, **kwargs
) )
@ -3839,7 +3837,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input, input: Input,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
yield from self.transform(iter([input]), config) yield from self.transform(iter([input]), config)
async def _atransform( async def _atransform(
@ -3898,7 +3896,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: AsyncIterator[Input], input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config( async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs input, self._atransform, config, **kwargs
): ):
@ -3909,7 +3907,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input, input: Input,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Input]: async def input_aiter() -> AsyncIterator[Input]:
yield input yield input
@ -4064,7 +4062,7 @@ class RunnableGenerator(Runnable[Input, Output]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
# Override the default implementation. # Override the default implementation.
# For a runnable generator, we need to provide the # For a runnable generator, we need to provide the
# module of the underlying function when creating the model. # module of the underlying function when creating the model.
@ -4100,7 +4098,7 @@ class RunnableGenerator(Runnable[Input, Output]):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
# Override the default implementation. # Override the default implementation.
# For a runnable generator, we need to provide the # For a runnable generator, we need to provide the
# module of the underlying function when creating the model. # module of the underlying function when creating the model.
@ -4346,7 +4344,7 @@ class RunnableLambda(Runnable[Input, Output]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""The pydantic schema for the input to this Runnable. """The pydantic schema for the input to this Runnable.
Args: Args:
@ -4414,7 +4412,7 @@ class RunnableLambda(Runnable[Input, Output]):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
# Override the default implementation. # Override the default implementation.
# For a runnable lambda, we need to provide the # For a runnable lambda, we need to provide the
# module of the underlying function when creating the model. # module of the underlying function when creating the model.
@ -4435,7 +4433,7 @@ class RunnableLambda(Runnable[Input, Output]):
) )
@property @property
def deps(self) -> List[Runnable]: def deps(self) -> list[Runnable]:
"""The dependencies of this Runnable. """The dependencies of this Runnable.
Returns: Returns:
@ -4449,7 +4447,7 @@ class RunnableLambda(Runnable[Input, Output]):
else: else:
objects = [] objects = []
deps: List[Runnable] = [] deps: list[Runnable] = []
for obj in objects: for obj in objects:
if isinstance(obj, Runnable): if isinstance(obj, Runnable):
deps.append(obj) deps.append(obj)
@ -4458,7 +4456,7 @@ class RunnableLambda(Runnable[Input, Output]):
return deps return deps
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs( return get_unique_config_specs(
spec for dep in self.deps for spec in dep.config_specs spec for dep in self.deps for spec in dep.config_specs
) )
@ -4944,7 +4942,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return create_model_v2( return create_model_v2(
self.get_name("Input"), self.get_name("Input"),
root=( root=(
@ -4962,12 +4960,12 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
) )
@property @property
def OutputType(self) -> Type[List[Output]]: def OutputType(self) -> type[list[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined] return List[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
schema = self.bound.get_output_schema(config) schema = self.bound.get_output_schema(config)
return create_model_v2( return create_model_v2(
self.get_name("Output"), self.get_name("Output"),
@ -4983,7 +4981,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
) )
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.bound.config_specs return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
@ -4994,42 +4992,42 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
def _invoke( def _invoke(
self, self,
inputs: List[Input], inputs: list[Input],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> List[Output]: ) -> list[Output]:
configs = [ configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
] ]
return self.bound.batch(inputs, configs, **kwargs) return self.bound.batch(inputs, configs, **kwargs)
def invoke( def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]: ) -> list[Output]:
return self._call_with_config(self._invoke, input, config, **kwargs) return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke( async def _ainvoke(
self, self,
inputs: List[Input], inputs: list[Input],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> List[Output]: ) -> list[Output]:
configs = [ configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
] ]
return await self.bound.abatch(inputs, configs, **kwargs) return await self.bound.abatch(inputs, configs, **kwargs)
async def ainvoke( async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]: ) -> list[Output]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs) return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
async def astream_events( async def astream_events(
@ -5074,7 +5072,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
""" """
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -5181,7 +5179,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
config: RunnableConfig = Field(default_factory=dict) config: RunnableConfig = Field(default_factory=dict)
"""The config to bind to the underlying Runnable.""" """The config to bind to the underlying Runnable."""
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field( config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field(
default_factory=list default_factory=list
) )
"""The config factories to bind to the underlying Runnable.""" """The config factories to bind to the underlying Runnable."""
@ -5210,10 +5208,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
kwargs: Optional[Mapping[str, Any]] = None, kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
config_factories: Optional[ config_factories: Optional[
List[Callable[[RunnableConfig], RunnableConfig]] list[Callable[[RunnableConfig], RunnableConfig]]
] = None, ] = None,
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, custom_input_type: Optional[Union[type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, custom_output_type: Optional[Union[type[Output], BaseModel]] = None,
**other_kwargs: Any, **other_kwargs: Any,
) -> None: ) -> None:
"""Create a RunnableBinding from a Runnable and kwargs. """Create a RunnableBinding from a Runnable and kwargs.
@ -5255,7 +5253,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_name(suffix, name=name) return self.bound.get_name(suffix, name=name)
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> type[Input]:
return ( return (
cast(Type[Input], self.custom_input_type) cast(Type[Input], self.custom_input_type)
if self.custom_input_type is not None if self.custom_input_type is not None
@ -5263,7 +5261,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) )
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
return ( return (
cast(Type[Output], self.custom_output_type) cast(Type[Output], self.custom_output_type)
if self.custom_output_type is not None if self.custom_output_type is not None
@ -5272,20 +5270,20 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
if self.custom_input_type is not None: if self.custom_input_type is not None:
return super().get_input_schema(config) return super().get_input_schema(config)
return self.bound.get_input_schema(merge_configs(self.config, config)) return self.bound.get_input_schema(merge_configs(self.config, config))
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
if self.custom_output_type is not None: if self.custom_output_type is not None:
return super().get_output_schema(config) return super().get_output_schema(config)
return self.bound.get_output_schema(merge_configs(self.config, config)) return self.bound.get_output_schema(merge_configs(self.config, config))
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.bound.config_specs return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
@ -5296,7 +5294,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -5330,12 +5328,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
@ -5352,12 +5350,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
@ -5380,7 +5378,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: Literal[False] = False, return_exceptions: Literal[False] = False,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[int, Output]]: ... ) -> Iterator[tuple[int, Output]]: ...
@overload @overload
def batch_as_completed( def batch_as_completed(
@ -5390,7 +5388,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: Literal[True], return_exceptions: Literal[True],
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ... ) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
def batch_as_completed( def batch_as_completed(
self, self,
@ -5399,7 +5397,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ) -> Iterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence): if isinstance(config, Sequence):
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
@ -5431,7 +5429,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: Literal[False] = False, return_exceptions: Literal[False] = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]: ... ) -> AsyncIterator[tuple[int, Output]]: ...
@overload @overload
def abatch_as_completed( def abatch_as_completed(
@ -5441,7 +5439,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: Literal[True], return_exceptions: Literal[True],
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ... ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
async def abatch_as_completed( async def abatch_as_completed(
self, self,
@ -5450,7 +5448,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence): if isinstance(config, Sequence):
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
@ -5590,7 +5588,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
""" """
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -5678,8 +5676,8 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
def with_types( def with_types(
self, self,
input_type: Optional[Union[Type[Input], BaseModel]] = None, input_type: Optional[Union[type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None, output_type: Optional[Union[type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
return self.__class__( return self.__class__(
bound=self.bound, bound=self.bound,
View File
@ -12,7 +12,6 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Dict,
Generator, Generator,
Iterable, Iterable,
Iterator, Iterator,
@ -56,13 +55,13 @@ class EmptyDict(TypedDict, total=False):
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable.""" """Configuration for a Runnable."""
tags: List[str] tags: list[str]
""" """
Tags for this call and any sub-calls (eg. a Chain calling an LLM). Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls. You can use these to filter calls.
""" """
metadata: Dict[str, Any] metadata: dict[str, Any]
""" """
Metadata for this call and any sub-calls (eg. a Chain calling an LLM). Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable. Keys should be strings, values should be JSON-serializable.
@ -90,7 +89,7 @@ class RunnableConfig(TypedDict, total=False):
Maximum number of times a call can recurse. If not provided, defaults to 25. Maximum number of times a call can recurse. If not provided, defaults to 25.
""" """
configurable: Dict[str, Any] configurable: dict[str, Any]
""" """
Runtime values for attributes previously made configurable on this Runnable, Runtime values for attributes previously made configurable on this Runnable,
or sub-Runnables, through .configurable_fields() or .configurable_alternatives(). or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
@ -205,7 +204,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
def get_config_list( def get_config_list(
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int
) -> List[RunnableConfig]: ) -> list[RunnableConfig]:
"""Get a list of configs from a single config or a list of configs. """Get a list of configs from a single config or a list of configs.
It is useful for subclasses overriding batch() or abatch(). It is useful for subclasses overriding batch() or abatch().
@ -255,7 +254,7 @@ def patch_config(
recursion_limit: Optional[int] = None, recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
configurable: Optional[Dict[str, Any]] = None, configurable: Optional[dict[str, Any]] = None,
) -> RunnableConfig: ) -> RunnableConfig:
"""Patch a config with new values. """Patch a config with new values.
View File
@ -8,12 +8,10 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Callable, Callable,
Dict,
Iterator, Iterator,
List, List,
Optional, Optional,
Sequence, Sequence,
Tuple,
Type, Type,
Union, Union,
cast, cast,
@ -69,27 +67,27 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> type[Input]:
return self.default.InputType return self.default.InputType
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
return self.default.OutputType return self.default.OutputType
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
runnable, config = self.prepare(config) runnable, config = self.prepare(config)
return runnable.get_input_schema(config) return runnable.get_input_schema(config)
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
runnable, config = self.prepare(config) runnable, config = self.prepare(config)
return runnable.get_output_schema(config) return runnable.get_output_schema(config)
@ -109,7 +107,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def prepare( def prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ) -> tuple[Runnable[Input, Output], RunnableConfig]:
"""Prepare the Runnable for invocation. """Prepare the Runnable for invocation.
Args: Args:
@ -127,7 +125,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@abstractmethod @abstractmethod
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ... ) -> tuple[Runnable[Input, Output], RunnableConfig]: ...
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@ -143,12 +141,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self.prepare(c) for c in configs] prepared = [self.prepare(c) for c in configs]
@ -164,7 +162,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return [] return []
def invoke( def invoke(
prepared: Tuple[Runnable[Input, Output], RunnableConfig], prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input: Input,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared bound, config = prepared
@ -185,12 +183,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self.prepare(c) for c in configs] prepared = [self.prepare(c) for c in configs]
@ -206,7 +204,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return [] return []
async def ainvoke( async def ainvoke(
prepared: Tuple[Runnable[Input, Output], RunnableConfig], prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input: Input,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared bound, config = prepared
@ -362,15 +360,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
) )
""" """
fields: Dict[str, AnyConfigurableField] fields: dict[str, AnyConfigurableField]
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableConfigurableFields. """Get the configuration specs for the RunnableConfigurableFields.
Returns: Returns:
@ -412,7 +410,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ) -> tuple[Runnable[Input, Output], RunnableConfig]:
config = ensure_config(config) config = ensure_config(config)
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = { configurable_fields = {
@ -467,7 +465,7 @@ _enums_for_spec: WeakValueDictionary[
Union[ Union[
ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField
], ],
Type[StrEnum], type[StrEnum],
] = WeakValueDictionary() ] = WeakValueDictionary()
_enums_for_spec_lock = threading.Lock() _enums_for_spec_lock = threading.Lock()
@ -532,7 +530,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which: ConfigurableField which: ConfigurableField
"""The ConfigurableField to use to choose between alternatives.""" """The ConfigurableField to use to choose between alternatives."""
alternatives: Dict[ alternatives: dict[
str, str,
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
] ]
@ -547,12 +545,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
the alternative named "gpt3" becomes "model==gpt3/temperature".""" the alternative named "gpt3" becomes "model==gpt3/temperature"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
with _enums_for_spec_lock: with _enums_for_spec_lock:
if which_enum := _enums_for_spec.get(self.which): if which_enum := _enums_for_spec.get(self.which):
pass pass
@ -612,7 +610,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ) -> tuple[Runnable[Input, Output], RunnableConfig]:
config = ensure_config(config) config = ensure_config(config)
which = config.get("configurable", {}).get(self.which.id, self.default_key) which = config.get("configurable", {}).get(self.which.id, self.default_key)
# remap configurable keys for the chosen alternative # remap configurable keys for the chosen alternative

View File

@ -8,14 +8,10 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict,
List,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Sequence, Sequence,
Tuple,
Type,
TypedDict, TypedDict,
Union, Union,
overload, overload,
@ -106,8 +102,8 @@ class Node(NamedTuple):
id: str id: str
name: str name: str
data: Union[Type[BaseModel], RunnableType] data: Union[type[BaseModel], RunnableType]
metadata: Optional[Dict[str, Any]] metadata: Optional[dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node: def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
"""Return a copy of the node with optional new id and name. """Return a copy of the node with optional new id and name.
@ -179,7 +175,7 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str: def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
"""Convert the data of a node to a string. """Convert the data of a node to a string.
Args: Args:
@ -202,7 +198,7 @@ def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
def node_data_json( def node_data_json(
node: Node, *, with_schemas: bool = False node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]: ) -> dict[str, Union[str, dict[str, Any]]]:
"""Convert the data of a node to a JSON-serializable format. """Convert the data of a node to a JSON-serializable format.
Args: Args:
@ -217,7 +213,7 @@ def node_data_json(
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
if isinstance(node.data, RunnableSerializable): if isinstance(node.data, RunnableSerializable):
json: Dict[str, Any] = { json: dict[str, Any] = {
"type": "runnable", "type": "runnable",
"data": { "data": {
"id": node.data.lc_id(), "id": node.data.lc_id(),
@ -265,10 +261,10 @@ class Graph:
edges: List of edges in the graph. Defaults to an empty list. edges: List of edges in the graph. Defaults to an empty list.
""" """
nodes: Dict[str, Node] = field(default_factory=dict) nodes: dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list) edges: list[Edge] = field(default_factory=list)
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]: def to_json(self, *, with_schemas: bool = False) -> dict[str, list[dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format. """Convert the graph to a JSON-serializable format.
Args: Args:
@ -282,7 +278,7 @@ class Graph:
node.id: i if is_uuid(node.id) else node.id node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values()) for i, node in enumerate(self.nodes.values())
} }
edges: List[Dict[str, Any]] = [] edges: list[dict[str, Any]] = []
for edge in self.edges: for edge in self.edges:
edge_dict = { edge_dict = {
"source": stable_node_ids[edge.source], "source": stable_node_ids[edge.source],
@ -315,10 +311,10 @@ class Graph:
def add_node( def add_node(
self, self,
data: Union[Type[BaseModel], RunnableType], data: Union[type[BaseModel], RunnableType],
id: Optional[str] = None, id: Optional[str] = None,
*, *,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
) -> Node: ) -> Node:
"""Add a node to the graph and return it. """Add a node to the graph and return it.
@ -386,7 +382,7 @@ class Graph:
def extend( def extend(
self, graph: Graph, *, prefix: str = "" self, graph: Graph, *, prefix: str = ""
) -> Tuple[Optional[Node], Optional[Node]]: ) -> tuple[Optional[Node], Optional[Node]]:
"""Add all nodes and edges from another graph. """Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs. Note this doesn't check for duplicates, nor does it connect the graphs.
@ -622,7 +618,7 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin.""" When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in graph.edges if edge.source not in exclude} targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: List[Node] = [] found: list[Node] = []
for node in graph.nodes.values(): for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets: if node.id not in exclude and node.id not in targets:
found.append(node) found.append(node)
@ -635,7 +631,7 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination.""" When drawing the graph, this node would be the destination."""
sources = {edge.source for edge in graph.edges if edge.target not in exclude} sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: List[Node] = [] found: list[Node] = []
for node in graph.nodes.values(): for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources: if node.id not in exclude and node.id not in sources:
found.append(node) found.append(node)
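A small usage sketch (not from the commit) of the `Graph` container whose `nodes`/`edges` annotations change above, exercising the first/last-node logic from these hunks; node names are illustrative:

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.graph import Graph

graph = Graph()  # nodes: dict[str, Node], edges: list[Edge]
start = graph.add_node(RunnableLambda(lambda x: x + 1), "add_one")
end = graph.add_node(RunnableLambda(lambda x: x * 2), "double")
graph.add_edge(start, end)

print(graph.first_node().name)  # 'add_one' -- the only node that is never an edge target
print(graph.last_node().name)   # 'double'  -- the only node that is never an edge source
```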
View File
@ -6,10 +6,8 @@ from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
List,
Optional, Optional,
Sequence, Sequence,
Type,
Union, Union,
) )
@ -238,7 +236,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_factory_config: Sequence[ConfigurableFieldSpec] history_factory_config: Sequence[ConfigurableFieldSpec]
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -366,7 +364,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
self._history_chain = history_chain self._history_chain = history_chain
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableWithMessageHistory.""" """Get the configuration specs for the RunnableWithMessageHistory."""
return get_unique_config_specs( return get_unique_config_specs(
super().config_specs + list(self.history_factory_config) super().config_specs + list(self.history_factory_config)
@ -374,10 +372,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
fields: Dict = {} fields: dict = {}
if self.input_messages_key and self.history_messages_key: if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = ( fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]], Union[str, BaseMessage, Sequence[BaseMessage]],
@ -398,13 +396,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
) )
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType output_type = self._history_chain.OutputType
return output_type return output_type
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable. """Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives Runnables that leverage the configurable_fields and configurable_alternatives
@ -430,15 +428,15 @@ class RunnableWithMessageHistory(RunnableBindingBase):
module_name=self.__class__.__module__, module_name=self.__class__.__module__,
) )
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return False return False
async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return True return True
def _get_input_messages( def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]: ) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages # If dictionary, try to pluck the single key representing messages
@ -481,7 +479,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def _get_output_messages( def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]: ) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages # If dictionary, try to pluck the single key representing messages
@ -514,7 +512,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
f"Got {output_val}." f"Got {output_val}."
) )
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy() messages = hist.messages.copy()
@ -527,8 +525,8 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return messages return messages
async def _aenter_history( async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig self, input: dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]: ) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = (await hist.aget_messages()).copy() messages = (await hist.aget_messages()).copy()
@ -621,7 +619,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return config return config
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]: def _get_parameter_names(callable_: GetSessionHistoryCallable) -> list[str]:
"""Get the parameter names of the callable.""" """Get the parameter names of the callable."""
sig = inspect.signature(callable_) sig = inspect.signature(callable_)
return list(sig.parameters.keys()) return list(sig.parameters.keys())
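An illustrative sketch of `RunnableWithMessageHistory` wired to an in-memory store; the lambda, the `store` dict, and the session id are invented for the example (the expected outputs assume the default behavior of prepending stored history to the new input):

```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store: dict[str, InMemoryChatMessageHistory] = {}

def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    return store.setdefault(session_id, InMemoryChatMessageHistory())

chain = RunnableWithMessageHistory(
    RunnableLambda(lambda messages: f"seen {len(messages)} message(s)"),
    get_session_history,
)
cfg = {"configurable": {"session_id": "s1"}}
print(chain.invoke("hi", config=cfg))        # seen 1 message(s)
print(chain.invoke("hi again", config=cfg))  # seen 3 message(s) -- history + new input
```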
View File
@ -13,10 +13,8 @@ from typing import (
Callable, Callable,
Dict, Dict,
Iterator, Iterator,
List,
Mapping, Mapping,
Optional, Optional,
Type,
Union, Union,
cast, cast,
) )
@ -144,7 +142,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
""" """
input_type: Optional[Type[Other]] = None input_type: Optional[type[Other]] = None
func: Optional[ func: Optional[
Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
@ -180,7 +178,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
] ]
] = None, ] = None,
*, *,
input_type: Optional[Type[Other]] = None, input_type: Optional[type[Other]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@ -194,7 +192,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -210,11 +208,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
def assign( def assign(
cls, cls,
**kwargs: Union[ **kwargs: Union[
Runnable[Dict[str, Any], Any], Runnable[dict[str, Any], Any],
Callable[[Dict[str, Any]], Any], Callable[[dict[str, Any]], Any],
Mapping[ Mapping[
str, str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
], ],
], ],
) -> RunnableAssign: ) -> RunnableAssign:
@ -228,7 +226,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A Runnable that merges the Dict input with the output produced by the A Runnable that merges the Dict input with the output produced by the
mapping argument. mapping argument.
""" """
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs)) return RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
def invoke( def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@ -392,9 +390,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
# returns {'input': 5, 'add_step': {'added': 15}} # returns {'input': 5, 'add_step': {'added': 15}}
""" """
mapper: RunnableParallel[Dict[str, Any]] mapper: RunnableParallel[dict[str, Any]]
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None: def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
@classmethod @classmethod
@ -402,7 +400,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -418,7 +416,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config) map_input_schema = self.mapper.get_input_schema(config)
if not issubclass(map_input_schema, RootModel): if not issubclass(map_input_schema, RootModel):
# ie. it's a dict # ie. it's a dict
@ -428,7 +426,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config) map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config) map_output_schema = self.mapper.get_output_schema(config)
if not issubclass(map_input_schema, RootModel) and not issubclass( if not issubclass(map_input_schema, RootModel) and not issubclass(
@ -453,7 +451,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().get_output_schema(config) return super().get_output_schema(config)
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.mapper.config_specs return self.mapper.config_specs
def get_graph(self, config: RunnableConfig | None = None) -> Graph: def get_graph(self, config: RunnableConfig | None = None) -> Graph:
@ -470,11 +468,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def _invoke( def _invoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
assert isinstance( assert isinstance(
input, dict input, dict
), "The input to RunnablePassthrough.assign() must be a dict." ), "The input to RunnablePassthrough.assign() must be a dict."
@ -490,19 +488,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def invoke( def invoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return self._call_with_config(self._invoke, input, config, **kwargs) return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke( async def _ainvoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
assert isinstance( assert isinstance(
input, dict input, dict
), "The input to RunnablePassthrough.assign() must be a dict." ), "The input to RunnablePassthrough.assign() must be a dict."
@ -518,19 +516,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def ainvoke( async def ainvoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs) return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _transform( def _transform(
self, self,
input: Iterator[Dict[str, Any]], input: Iterator[dict[str, Any]],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
# collect mapper keys # collect mapper keys
mapper_keys = set(self.mapper.steps__.keys()) mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough # create two streams, one for the map and one for the passthrough
@ -572,21 +570,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def transform( def transform(
self, self,
input: Iterator[Dict[str, Any]], input: Iterator[dict[str, Any]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any | None, **kwargs: Any | None,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config( yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs input, self._transform, config, **kwargs
) )
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Dict[str, Any]], input: AsyncIterator[dict[str, Any]],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
# collect mapper keys # collect mapper keys
mapper_keys = set(self.mapper.steps__.keys()) mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough # create two streams, one for the map and one for the passthrough
@ -622,10 +620,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Dict[str, Any]], input: AsyncIterator[dict[str, Any]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config( async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs input, self._atransform, config, **kwargs
): ):
@ -633,19 +631,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def stream( def stream(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs) return self.transform(iter([input]), config, **kwargs)
async def astream( async def astream(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]: async def input_aiter() -> AsyncIterator[dict[str, Any]]:
yield input yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs): async for chunk in self.atransform(input_aiter(), config, **kwargs):
@ -683,9 +681,9 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
print(output_data) # Output: {'name': 'John', 'age': 30} print(output_data) # Output: {'name': 'John', 'age': 30}
""" """
keys: Union[str, List[str]] keys: Union[str, list[str]]
def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None: def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None:
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg] super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
@classmethod @classmethod
@ -693,7 +691,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -707,7 +705,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
) )
return super().get_name(suffix, name=name) return super().get_name(suffix, name=name)
def _pick(self, input: Dict[str, Any]) -> Any: def _pick(self, input: dict[str, Any]) -> Any:
assert isinstance( assert isinstance(
input, dict input, dict
), "The input to RunnablePassthrough.assign() must be a dict." ), "The input to RunnablePassthrough.assign() must be a dict."
@ -723,36 +721,36 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def _invoke( def _invoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return self._pick(input) return self._pick(input)
def invoke( def invoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return self._call_with_config(self._invoke, input, config, **kwargs) return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke( async def _ainvoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return self._pick(input) return self._pick(input)
async def ainvoke( async def ainvoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs) return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _transform( def _transform(
self, self,
input: Iterator[Dict[str, Any]], input: Iterator[dict[str, Any]],
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
for chunk in input: for chunk in input:
picked = self._pick(chunk) picked = self._pick(chunk)
if picked is not None: if picked is not None:
@ -760,18 +758,18 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def transform( def transform(
self, self,
input: Iterator[Dict[str, Any]], input: Iterator[dict[str, Any]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config( yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs input, self._transform, config, **kwargs
) )
async def _atransform( async def _atransform(
self, self,
input: AsyncIterator[Dict[str, Any]], input: AsyncIterator[dict[str, Any]],
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async for chunk in input: async for chunk in input:
picked = self._pick(chunk) picked = self._pick(chunk)
if picked is not None: if picked is not None:
@ -779,10 +777,10 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def atransform( async def atransform(
self, self,
input: AsyncIterator[Dict[str, Any]], input: AsyncIterator[dict[str, Any]],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config( async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs input, self._atransform, config, **kwargs
): ):
@ -790,19 +788,19 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def stream( def stream(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs) return self.transform(iter([input]), config, **kwargs)
async def astream( async def astream(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]: async def input_aiter() -> AsyncIterator[dict[str, Any]]:
yield input yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs): async for chunk in self.atransform(input_aiter(), config, **kwargs):
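A short sketch (not part of the commit) showing the `RunnableAssign`/`RunnablePick` behavior the `dict[str, Any]` annotations above describe; the keys are illustrative:

```python
from langchain_core.runnables import RunnablePassthrough, RunnablePick

# assign() merges the input dict with newly computed keys
chain = RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])
print(chain.invoke({"a": 1, "b": 2}))  # {'a': 1, 'b': 2, 'total': 3}

# RunnablePick keeps only the requested keys (a str picks one value, a list picks a sub-dict)
pick = RunnablePick(["name", "age"])
print(pick.invoke({"name": "Ada", "age": 36, "role": "admin"}))  # {'name': 'Ada', 'age': 36}
```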
View File
@ -71,7 +71,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables: Mapping[str, Runnable[Any, Output]] runnables: Mapping[str, Runnable[Any, Output]]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs( return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs spec for step in self.runnables.values() for spec in step.config_specs
) )
@ -94,7 +94,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -125,12 +125,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
def batch( def batch(
self, self,
inputs: List[RouterInput], inputs: list[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
if not inputs: if not inputs:
return [] return []
@ -160,12 +160,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[RouterInput], inputs: list[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
if not inputs: if not inputs:
return [] return []
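For context, a minimal `RouterRunnable` sketch matching the `batch` signatures above; the route names and inputs are invented:

```python
from langchain_core.runnables import RouterRunnable, RunnableLambda

router = RouterRunnable(
    runnables={
        "add_one": RunnableLambda(lambda x: x + 1),
        "square": RunnableLambda(lambda x: x * x),
    }
)
# each RouterInput names the route ("key") and carries the payload ("input")
print(router.invoke({"key": "square", "input": 4}))  # 16
print(router.batch([
    {"key": "add_one", "input": 1},
    {"key": "square", "input": 3},
]))  # [2, 9]
```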
View File
@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Literal, Sequence, Union from typing import Any, Literal, Sequence, Union
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
@ -110,7 +110,7 @@ class BaseStreamEvent(TypedDict):
Each child Runnable that gets invoked as part of the execution of a parent Runnable Each child Runnable that gets invoked as part of the execution of a parent Runnable
is assigned its own unique ID. is assigned its own unique ID.
""" """
tags: NotRequired[List[str]] tags: NotRequired[list[str]]
"""Tags associated with the Runnable that generated this event. """Tags associated with the Runnable that generated this event.
Tags are always inherited from parent Runnables. Tags are always inherited from parent Runnables.
@ -118,7 +118,7 @@ class BaseStreamEvent(TypedDict):
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})` Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`. or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
""" """
metadata: NotRequired[Dict[str, Any]] metadata: NotRequired[dict[str, Any]]
"""Metadata associated with the Runnable that generated this event. """Metadata associated with the Runnable that generated this event.
Metadata can either be bound to a Runnable using Metadata can either be bound to a Runnable using
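The same `NotRequired[list[...]]`/`NotRequired[dict[...]]` pattern, shown in a standalone illustrative TypedDict (`ExampleEvent` is hypothetical, not part of the event schema):

```python
from typing import Any

from typing_extensions import NotRequired, TypedDict

class ExampleEvent(TypedDict):
    """Shape mirroring BaseStreamEvent's optional fields (illustrative only)."""
    run_id: str
    tags: NotRequired[list[str]]           # PEP 585: builtin list, no typing.List
    metadata: NotRequired[dict[str, Any]]  # builtin dict, no typing.Dict

event: ExampleEvent = {"run_id": "abc", "tags": ["demo"]}
```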
View File
@ -18,13 +18,11 @@ from typing import (
Coroutine, Coroutine,
Dict, Dict,
Iterable, Iterable,
List,
Mapping, Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Sequence, Sequence,
Set,
TypeVar, TypeVar,
Union, Union,
) )
@ -126,7 +124,7 @@ def asyncio_accepts_context() -> bool:
class IsLocalDict(ast.NodeVisitor): class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict.""" """Check if a name is a local dict."""
def __init__(self, name: str, keys: Set[str]) -> None: def __init__(self, name: str, keys: set[str]) -> None:
"""Initialize the visitor. """Initialize the visitor.
Args: Args:
@ -181,7 +179,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
"""Check if the first argument of a function is a dict.""" """Check if the first argument of a function is a dict."""
def __init__(self) -> None: def __init__(self) -> None:
self.keys: Set[str] = set() self.keys: set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any: def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function. """Visit a lambda function.
@ -230,8 +228,8 @@ class NonLocals(ast.NodeVisitor):
"""Get nonlocal variables accessed.""" """Get nonlocal variables accessed."""
def __init__(self) -> None: def __init__(self) -> None:
self.loads: Set[str] = set() self.loads: set[str] = set()
self.stores: Set[str] = set() self.stores: set[str] = set()
def visit_Name(self, node: ast.Name) -> Any: def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node. """Visit a name node.
@ -271,7 +269,7 @@ class FunctionNonLocals(ast.NodeVisitor):
"""Get the nonlocal variables accessed of a function.""" """Get the nonlocal variables accessed of a function."""
def __init__(self) -> None: def __init__(self) -> None:
self.nonlocals: Set[str] = set() self.nonlocals: set[str] = set()
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition. """Visit a function definition.
@ -335,7 +333,7 @@ class GetLambdaSource(ast.NodeVisitor):
self.source = ast.unparse(node) self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]: def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]:
"""Get the keys of the first argument of a function if it is a dict. """Get the keys of the first argument of a function if it is a dict.
Args: Args:
@ -378,7 +376,7 @@ def get_lambda_source(func: Callable) -> Optional[str]:
return name return name
def get_function_nonlocals(func: Callable) -> List[Any]: def get_function_nonlocals(func: Callable) -> list[Any]:
"""Get the nonlocal variables accessed by a function. """Get the nonlocal variables accessed by a function.
Args: Args:
@ -392,7 +390,7 @@ def get_function_nonlocals(func: Callable) -> List[Any]:
tree = ast.parse(textwrap.dedent(code)) tree = ast.parse(textwrap.dedent(code))
visitor = FunctionNonLocals() visitor = FunctionNonLocals()
visitor.visit(tree) visitor.visit(tree)
values: List[Any] = [] values: list[Any] = []
closure = inspect.getclosurevars(func) closure = inspect.getclosurevars(func)
candidates = {**closure.globals, **closure.nonlocals} candidates = {**closure.globals, **closure.nonlocals}
for k, v in candidates.items(): for k, v in candidates.items():
@ -608,12 +606,12 @@ class ConfigurableFieldSpec(NamedTuple):
description: Optional[str] = None description: Optional[str] = None
default: Any = None default: Any = None
is_shared: bool = False is_shared: bool = False
dependencies: Optional[List[str]] = None dependencies: Optional[list[str]] = None
def get_unique_config_specs( def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec], specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]: ) -> list[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs. """Get the unique config specs from a sequence of config specs.
Args: Args:
@ -628,7 +626,7 @@ def get_unique_config_specs(
grouped = groupby( grouped = groupby(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
) )
unique: List[ConfigurableFieldSpec] = [] unique: list[ConfigurableFieldSpec] = []
for id, dupes in grouped: for id, dupes in grouped:
first = next(dupes) first = next(dupes)
others = list(dupes) others = list(dupes)
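A small sketch of the deduplication `get_unique_config_specs` performs on the final hunk above; the spec ids are invented:

```python
from langchain_core.runnables.utils import ConfigurableFieldSpec, get_unique_config_specs

spec = ConfigurableFieldSpec(id="llm", annotation=str)
dupe = ConfigurableFieldSpec(id="llm", annotation=str)    # exact duplicate, collapsed
other = ConfigurableFieldSpec(id="temp", annotation=float)

unique = get_unique_config_specs([spec, dupe, other])
print([s.id for s in unique])  # ['llm', 'temp']
```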
View File
@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -142,10 +142,10 @@ class Operation(FilterDirective):
""" """
operator: Operator operator: Operator
arguments: List[FilterDirective] arguments: list[FilterDirective]
def __init__( def __init__(
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any self, operator: Operator, arguments: list[FilterDirective], **kwargs: Any
) -> None: ) -> None:
# super exists from BaseModel # super exists from BaseModel
super().__init__( # type: ignore[call-arg] super().__init__( # type: ignore[call-arg]
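An illustrative construction of the `Operation`/`Comparison` tree whose `arguments` annotation changes above; the attribute names and values are invented:

```python
from langchain_core.structured_query import Comparator, Comparison, Operation, Operator

flt = Operation(
    operator=Operator.AND,
    arguments=[  # arguments: list[FilterDirective]
        Comparison(comparator=Comparator.EQ, attribute="genre", value="scifi"),
        Comparison(comparator=Comparator.GT, attribute="year", value=2000),
    ],
)
print(flt.operator, len(flt.arguments))  # Operator.AND 2
```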
View File
@ -13,12 +13,9 @@ from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
List,
Literal, Literal,
Optional, Optional,
Sequence, Sequence,
Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -78,11 +75,11 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: Type[Any]) -> bool: def _is_annotated_type(typ: type[Any]) -> bool:
return get_origin(typ) is Annotated return get_origin(typ) is Annotated
def _get_annotation_description(arg_type: Type) -> str | None: def _get_annotation_description(arg_type: type) -> str | None:
if _is_annotated_type(arg_type): if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type) annotated_args = get_args(arg_type)
for annotation in annotated_args[1:]: for annotation in annotated_args[1:]:
@ -92,7 +89,7 @@ def _get_annotation_description(arg_type: Type) -> str | None:
def _get_filtered_args( def _get_filtered_args(
inferred_model: Type[BaseModel], inferred_model: type[BaseModel],
func: Callable, func: Callable,
*, *,
filter_args: Sequence[str], filter_args: Sequence[str],
@ -112,7 +109,7 @@ def _get_filtered_args(
def _parse_python_function_docstring( def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function. """Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide. Assumes the function docstring follows Google Python style guide.
@ -141,7 +138,7 @@ def _infer_arg_descriptions(
*, *,
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = False, error_on_invalid_docstring: bool = False,
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Infer argument descriptions from a function's docstring.""" """Infer argument descriptions from a function's docstring."""
if hasattr(inspect, "get_annotations"): if hasattr(inspect, "get_annotations"):
# This is for python < 3.10 # This is for python < 3.10
@ -218,7 +215,7 @@ def create_schema_from_function(
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = False, error_on_invalid_docstring: bool = False,
include_injected: bool = True, include_injected: bool = True,
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Create a pydantic schema from a function's signature. """Create a pydantic schema from a function's signature.
Args: Args:
@ -273,7 +270,7 @@ def create_schema_from_function(
filter_args_ = filter_args filter_args_ = filter_args
else: else:
# Handle classmethods and instance methods # Handle classmethods and instance methods
existing_params: List[str] = list(sig.parameters.keys()) existing_params: list[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class: if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS) filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
else: else:
@ -395,13 +392,13 @@ class ChildTool(BaseTool):
description="Callback manager to add to the run trace.", description="Callback manager to add to the run trace.",
) )
) )
tags: Optional[List[str]] = None tags: Optional[list[str]] = None
"""Optional list of tags associated with the tool. Defaults to None. """Optional list of tags associated with the tool. Defaults to None.
These tags will be associated with each call to this tool, These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
""" """
metadata: Optional[Dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None. """Optional metadata associated with the tool. Defaults to None.
This metadata will be associated with each call to this tool, This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
@ -451,7 +448,7 @@ class ChildTool(BaseTool):
return self.get_input_schema().model_json_schema()["properties"] return self.get_input_schema().model_json_schema()["properties"]
@property @property
def tool_call_schema(self) -> Type[BaseModel]: def tool_call_schema(self) -> type[BaseModel]:
full_schema = self.get_input_schema() full_schema = self.get_input_schema()
fields = [] fields = []
for name, type_ in _get_all_basemodel_annotations(full_schema).items(): for name, type_ in _get_all_basemodel_annotations(full_schema).items():
@ -465,7 +462,7 @@ class ChildTool(BaseTool):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""The tool's input schema. """The tool's input schema.
Args: Args:
@ -481,7 +478,7 @@ class ChildTool(BaseTool):
def invoke( def invoke(
self, self,
input: Union[str, Dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -490,7 +487,7 @@ class ChildTool(BaseTool):
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -499,7 +496,7 @@ class ChildTool(BaseTool):
# --- Tool --- # --- Tool ---
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]: def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model. """Convert tool input to a pydantic model.
Args: Args:
@ -536,7 +533,7 @@ class ChildTool(BaseTool):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used. """Raise deprecation warning if callback_manager is used.
Args: Args:
@ -574,7 +571,7 @@ class ChildTool(BaseTool):
kwargs["run_manager"] = kwargs["run_manager"].get_sync() kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs) return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input) tool_input = self._parse_input(tool_input)
# For backwards compatibility, if run_input is a string, # For backwards compatibility, if run_input is a string,
# pass as a positional argument. # pass as a positional argument.
@ -585,14 +582,14 @@ class ChildTool(BaseTool):
def run( def run(
self, self,
tool_input: Union[str, Dict[str, Any]], tool_input: Union[str, dict[str, Any]],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
@ -696,14 +693,14 @@ class ChildTool(BaseTool):
async def arun( async def arun(
self, self,
tool_input: Union[str, Dict], tool_input: Union[str, dict],
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
start_color: Optional[str] = "green", start_color: Optional[str] = "green",
color: Optional[str] = "green", color: Optional[str] = "green",
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None, run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
@ -866,7 +863,7 @@ def _prep_run_args(
input: Union[str, dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
**kwargs: Any, **kwargs: Any,
) -> Tuple[Union[str, Dict], Dict]: ) -> tuple[Union[str, dict], dict]:
config = ensure_config(config) config = ensure_config(config)
if _is_tool_call(input): if _is_tool_call(input):
tool_call_id: Optional[str] = cast(ToolCall, input)["id"] tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
@ -933,7 +930,7 @@ def _stringify(content: Any) -> str:
return str(content) return str(content)
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]: def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
if isinstance(func, functools.partial): if isinstance(func, functools.partial):
func = func.func func = func.func
try: try:
@ -956,7 +953,7 @@ class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model.""" """Annotation for a Tool arg that is **not** meant to be generated by a model."""
def _is_injected_arg_type(type_: Type) -> bool: def _is_injected_arg_type(type_: type) -> bool:
return any( return any(
isinstance(arg, InjectedToolArg) isinstance(arg, InjectedToolArg)
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
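`_is_injected_arg_type` is the predicate that keeps `InjectedToolArg`-annotated parameters out of the model-facing schema (surfaced as `tool_call_schema`) while leaving them required at call time. A short usage sketch (hypothetical tool, ours):

from typing import Annotated

from langchain_core.tools import InjectedToolArg, tool

@tool
def save_note(text: str, user_id: Annotated[str, InjectedToolArg]) -> str:
    """Save a note for a user."""
    return f"saved for {user_id}: {text}"

# user_id is excluded from the schema shown to the model...
print(list(save_note.tool_call_schema.model_json_schema()["properties"]))  # ['text']

# ...but must still be supplied by the caller at invocation time
print(save_note.invoke({"text": "hi", "user_id": "u-123"}))  # saved for u-123: hi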
@ -966,10 +963,10 @@ def _is_injected_arg_type(type_: Type) -> bool:
def _get_all_basemodel_annotations( def _get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]: ) -> dict[str, type]:
# cls has no subscript: cls = FooBar # cls has no subscript: cls = FooBar
if isinstance(cls, type): if isinstance(cls, type):
annotations: Dict[str, Type] = {} annotations: dict[str, type] = {}
for name, param in inspect.signature(cls).parameters.items(): for name, param in inspect.signature(cls).parameters.items():
# Exclude hidden init args added by pydantic Config. For example if # Exclude hidden init args added by pydantic Config. For example if
# BaseModel(extra="allow") then "extra_data" will be part of init sig. # BaseModel(extra="allow") then "extra_data" will be part of init sig.

@ -979,7 +976,7 @@ def _get_all_basemodel_annotations(
) and name not in fields: ) and name not in fields:
continue continue
annotations[name] = param.annotation annotations[name] = param.annotation
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple()) orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
# cls has subscript: cls = FooBar[int] # cls has subscript: cls = FooBar[int]
else: else:
annotations = _get_all_basemodel_annotations( annotations = _get_all_basemodel_annotations(
@ -1011,7 +1008,7 @@ def _get_all_basemodel_annotations(
# parent_origin = Baz, # parent_origin = Baz,
# generic_type_vars = (type vars in Baz) # generic_type_vars = (type vars in Baz)
# generic_map = {type var in Baz: str} # generic_map = {type var in Baz: str}
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple()) generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
generic_map = { generic_map = {
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent)) type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
} }
@ -1027,10 +1024,10 @@ def _get_all_basemodel_annotations(
def _replace_type_vars( def _replace_type_vars(
type_: Type, type_: type,
generic_map: Optional[Dict[TypeVar, Type]] = None, generic_map: Optional[dict[TypeVar, type]] = None,
default_to_bound: bool = True, default_to_bound: bool = True,
) -> Type: ) -> type:
generic_map = generic_map or {} generic_map = generic_map or {}
if isinstance(type_, TypeVar): if isinstance(type_, TypeVar):
if type_ in generic_map: if type_ in generic_map:
@ -1043,7 +1040,7 @@ def _replace_type_vars(
new_args = tuple( new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args _replace_type_vars(arg, generic_map, default_to_bound) for arg in args
) )
return _py_38_safe_origin(origin)[new_args] return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else: else:
return type_ return type_
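`_get_all_basemodel_annotations` and `_replace_type_vars` introspect annotations through `get_origin`/`get_args`, which treat the old and new spellings identically; that is what makes the mechanical swap safe on Python 3.9+. A quick standalone check (ours):

from typing import Dict, get_args, get_origin

# typing aliases and builtin generics expose the same origin and args
assert get_origin(Dict[str, int]) is get_origin(dict[str, int]) is dict
assert get_args(Dict[str, int]) == get_args(dict[str, int]) == (str, int)

# re-parameterizing the origin, as _replace_type_vars does, yields a builtin generic
origin, args = get_origin(dict[str, int]), get_args(dict[str, int])
print(origin[args])  # dict[str, int]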
@ -1052,5 +1049,5 @@ class BaseToolkit(BaseModel, ABC):
"""Base Toolkit representing a collection of related tools.""" """Base Toolkit representing a collection of related tools."""
@abstractmethod @abstractmethod
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""

View File

@ -8,7 +8,7 @@ from langchain_core.tools.base import BaseTool
ToolsRenderer = Callable[[List[BaseTool]], str] ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str: def render_text_description(tools: list[BaseTool]) -> str:
"""Render the tool name and description in plain text. """Render the tool name and description in plain text.
Args: Args:
@ -36,7 +36,7 @@ def render_text_description(tools: List[BaseTool]) -> str:
return "\n".join(descriptions) return "\n".join(descriptions)
def render_text_description_and_args(tools: List[BaseTool]) -> str: def render_text_description_and_args(tools: list[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text. """Render the tool name, description, and args in plain text.
Args: Args:
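For reference, a short usage sketch of the renderer (ours; the exact output format depends on the installed version):

from langchain_core.tools import tool
from langchain_core.tools.render import render_text_description

@tool
def search(query: str) -> str:
    """Look up things online."""
    return "results for " + query

# prints roughly: search(query: str) -> str - Look up things online.
print(render_text_description([search]))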

View File

@ -5,10 +5,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Dict,
Optional, Optional,
Tuple,
Type,
Union, Union,
) )
@ -40,7 +37,7 @@ class Tool(BaseTool):
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -65,7 +62,7 @@ class Tool(BaseTool):
# assume it takes a single string input. # assume it takes a single string input.
return {"tool_input": {"type": "string"}} return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
"""Convert tool input to pydantic model.""" """Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input) args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input # For backwards compatibility. The tool must be run with a single input
@ -131,7 +128,7 @@ class Tool(BaseTool):
name: str, # We keep these required to support backwards compatibility name: str, # We keep these required to support backwards compatibility
description: str, description: str,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[type[BaseModel]] = None,
coroutine: Optional[ coroutine: Optional[
Callable[..., Awaitable[Any]] Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func ] = None, # This is last for compatibility, but should be after func
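The `Tool` hunks are the same swap; note that `_to_args_and_kwargs` still funnels a bare string into a single positional argument. A sketch with a hypothetical function (ours):

from langchain_core.tools import Tool

def lookup(city: str) -> str:
    return f"weather in {city}: sunny"

weather = Tool(
    name="weather",
    description="Look up the weather for a city.",
    func=lookup,
)

# a bare string is passed through as the single positional argument
print(weather.run("Paris"))  # weather in Paris: sunny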

View File

@ -6,11 +6,8 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Dict,
List,
Literal, Literal,
Optional, Optional,
Type,
Union, Union,
) )
@ -50,7 +47,7 @@ class StructuredTool(BaseTool):
# TODO: Is this needed? # TODO: Is this needed?
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict, ToolCall], input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -112,7 +109,7 @@ class StructuredTool(BaseTool):
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[type[BaseModel]] = None,
infer_schema: bool = True, infer_schema: bool = True,
*, *,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
@ -209,7 +206,7 @@ class StructuredTool(BaseTool):
) )
def _filter_schema_args(func: Callable) -> List[str]: def _filter_schema_args(func: Callable) -> list[str]:
filter_args = list(FILTERED_ARGS) filter_args = list(FILTERED_ARGS)
if config_param := _get_runnable_config_param(func): if config_param := _get_runnable_config_param(func):
filter_args.append(config_param) filter_args.append(config_param)
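Where the `args_schema: Optional[type[BaseModel]]` annotation matters in practice is `from_function` with schema inference. A sketch (ours):

from langchain_core.tools import StructuredTool

def power(base: int, exp: int) -> int:
    """Raise base to the given exponent."""
    return base**exp

tool = StructuredTool.from_function(func=power)
print(tool.invoke({"base": 2, "exp": 10}))  # 1024
# roughly: {'base': {'title': 'Base', 'type': 'integer'}, 'exp': {'title': 'Exp', 'type': 'integer'}}
print(tool.args)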

View File

@ -8,8 +8,6 @@ from abc import ABC, abstractmethod
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Dict,
List,
Optional, Optional,
Sequence, Sequence,
Union, Union,
@ -52,13 +50,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -93,13 +91,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_llm_start( def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
prompts: List[str], prompts: list[str],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -238,13 +236,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chain_start( def on_chain_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None, run_type: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -282,10 +280,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chain_end( def on_chain_end(
self, self,
outputs: Dict[str, Any], outputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
"""End a trace for a chain run. """End a trace for a chain run.
@ -313,7 +311,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self, self,
error: BaseException, error: BaseException,
*, *,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -340,15 +338,15 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_tool_start( def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
input_str: str, input_str: str,
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
"""Start a trace for a tool run. """Start a trace for a tool run.
@ -429,13 +427,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_retriever_start( def on_retriever_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
query: str, query: str,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -556,13 +554,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chat_model_start( async def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -585,13 +583,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_llm_start( async def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
prompts: List[str], prompts: list[str],
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
llm_run = self._create_llm_run( llm_run = self._create_llm_run(
@ -642,7 +640,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
llm_run = self._complete_llm_run( llm_run = self._complete_llm_run(
@ -658,7 +656,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
llm_run = self._errored_llm_run( llm_run = self._errored_llm_run(
@ -670,13 +668,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chain_start( async def on_chain_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None, run_type: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -697,10 +695,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chain_end( async def on_chain_end(
self, self,
outputs: Dict[str, Any], outputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
chain_run = self._complete_chain_run( chain_run = self._complete_chain_run(
@ -716,7 +714,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self, self,
error: BaseException, error: BaseException,
*, *,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -731,15 +729,15 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_tool_start( async def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
input_str: str, input_str: str,
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
tool_run = self._create_tool_run( tool_run = self._create_tool_run(
@ -776,7 +774,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
tool_run = self._errored_tool_run( tool_run = self._errored_tool_run(
@ -788,13 +786,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_retriever_start( async def on_retriever_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
query: str, query: str,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -819,7 +817,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
retrieval_run = self._errored_retrieval_run( retrieval_run = self._errored_retrieval_run(
@ -839,7 +837,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
retrieval_run = self._complete_retrieval_run( retrieval_run = self._complete_retrieval_run(
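Every hunk in this tracer file is signature-only; control flow is untouched. For orientation, `_persist_run` is the lone abstract method on `BaseTracer`, so a minimal concrete tracer looks like this (hypothetical sketch, ours):

from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

class PrintTracer(BaseTracer):
    """Hypothetical tracer: print every run as it finishes."""

    def _persist_run(self, run: Run) -> None:
        print(run.run_type, run.name, run.error or "ok")

# attach it per call, e.g.:
# chain.invoke(inputs, config={"callbacks": [PrintTracer()]})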

View File

@ -6,10 +6,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Generator, Generator,
List,
Optional, Optional,
Tuple,
Type,
Union, Union,
cast, cast,
) )
@ -53,7 +50,7 @@ def tracing_v2_enabled(
project_name: Optional[str] = None, project_name: Optional[str] = None,
*, *,
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
client: Optional[LangSmithClient] = None, client: Optional[LangSmithClient] = None,
) -> Generator[LangChainTracer, None, None]: ) -> Generator[LangChainTracer, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith. """Instruct LangChain to log all runs in context to LangSmith.
@ -169,11 +166,11 @@ def _get_tracer_project() -> str:
) )
_configure_hooks: List[ _configure_hooks: list[
Tuple[ tuple[
ContextVar[Optional[BaseCallbackHandler]], ContextVar[Optional[BaseCallbackHandler]],
bool, bool,
Optional[Type[BaseCallbackHandler]], Optional[type[BaseCallbackHandler]],
Optional[str], Optional[str],
] ]
] = [] ] = []
@ -182,7 +179,7 @@ _configure_hooks: List[
def register_configure_hook( def register_configure_hook(
context_var: ContextVar[Optional[Any]], context_var: ContextVar[Optional[Any]],
inheritable: bool, inheritable: bool,
handle_class: Optional[Type[BaseCallbackHandler]] = None, handle_class: Optional[type[BaseCallbackHandler]] = None,
env_var: Optional[str] = None, env_var: Optional[str] = None,
) -> None: ) -> None:
"""Register a configure hook. """Register a configure hook.

View File

@ -11,13 +11,9 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Coroutine, Coroutine,
Dict,
List,
Literal, Literal,
Optional, Optional,
Sequence, Sequence,
Set,
Tuple,
Union, Union,
cast, cast,
) )
@ -81,9 +77,9 @@ class _TracerCore(ABC):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self._schema_format = _schema_format # For internal use only API will change. self._schema_format = _schema_format # For internal use only API will change.
self.run_map: Dict[str, Run] = {} self.run_map: dict[str, Run] = {}
"""Map of run ID to run. Cleared on run end.""" """Map of run ID to run. Cleared on run end."""
self.order_map: Dict[UUID, Tuple[UUID, str]] = {} self.order_map: dict[UUID, tuple[UUID, str]] = {}
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed.""" """Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
@abstractmethod @abstractmethod
@ -137,7 +133,7 @@ class _TracerCore(ABC):
self.run_map[str(run.id)] = run self.run_map[str(run.id)] = run
def _get_run( def _get_run(
self, run_id: UUID, run_type: Union[str, Set[str], None] = None self, run_id: UUID, run_type: Union[str, set[str], None] = None
) -> Run: ) -> Run:
try: try:
run = self.run_map[str(run_id)] run = self.run_map[str(run_id)]
@ -145,7 +141,7 @@ class _TracerCore(ABC):
raise TracerException(f"No indexed run ID {run_id}.") from exc raise TracerException(f"No indexed run ID {run_id}.") from exc
if isinstance(run_type, str): if isinstance(run_type, str):
run_types: Union[Set[str], None] = {run_type} run_types: Union[set[str], None] = {run_type}
else: else:
run_types = run_type run_types = run_type
if run_types is not None and run.run_type not in run_types: if run_types is not None and run.run_type not in run_types:
@ -157,12 +153,12 @@ class _TracerCore(ABC):
def _create_chat_model_run( def _create_chat_model_run(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -200,12 +196,12 @@ class _TracerCore(ABC):
def _create_llm_run( def _create_llm_run(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
prompts: List[str], prompts: list[str],
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -239,7 +235,7 @@ class _TracerCore(ABC):
Append token event to LLM run and return the run. Append token event to LLM run and return the run.
""" """
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token} event_kwargs: dict[str, Any] = {"token": token}
if chunk: if chunk:
event_kwargs["chunk"] = chunk event_kwargs["chunk"] = chunk
llm_run.events.append( llm_run.events.append(
@ -258,7 +254,7 @@ class _TracerCore(ABC):
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
llm_run = self._get_run(run_id) llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = { retry_d: dict[str, Any] = {
"slept": retry_state.idle_for, "slept": retry_state.idle_for,
"attempt": retry_state.attempt_number, "attempt": retry_state.attempt_number,
} }
@ -306,12 +302,12 @@ class _TracerCore(ABC):
def _create_chain_run( def _create_chain_run(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
inputs: Dict[str, Any], inputs: dict[str, Any],
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None, run_type: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -358,9 +354,9 @@ class _TracerCore(ABC):
def _complete_chain_run( def _complete_chain_run(
self, self,
outputs: Dict[str, Any], outputs: dict[str, Any],
run_id: UUID, run_id: UUID,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
"""Update a chain run with outputs and end time.""" """Update a chain run with outputs and end time."""
@ -375,7 +371,7 @@ class _TracerCore(ABC):
def _errored_chain_run( def _errored_chain_run(
self, self,
error: BaseException, error: BaseException,
inputs: Optional[Dict[str, Any]], inputs: Optional[dict[str, Any]],
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -389,14 +385,14 @@ class _TracerCore(ABC):
def _create_tool_run( def _create_tool_run(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
input_str: str, input_str: str,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
"""Create a tool run.""" """Create a tool run."""
@ -428,7 +424,7 @@ class _TracerCore(ABC):
def _complete_tool_run( def _complete_tool_run(
self, self,
output: Dict[str, Any], output: dict[str, Any],
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -454,12 +450,12 @@ class _TracerCore(ABC):
def _create_retrieval_run( def _create_retrieval_run(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
query: str, query: str,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
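The one semantically interesting line here is `_get_run`'s `Union[str, set[str], None]` filter, which normalizes a bare string into a one-element set. The pattern, extracted into a standalone sketch (ours):

from typing import Optional, Union

def normalize_run_types(run_type: Union[str, set[str], None]) -> Optional[set[str]]:
    # mirrors _TracerCore._get_run: a bare string becomes a one-element set
    if isinstance(run_type, str):
        return {run_type}
    return run_type

print(normalize_run_types("llm"))                  # {'llm'}
print(normalize_run_types({"llm", "chat_model"}))  # {'llm', 'chat_model'}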

View File

@ -6,7 +6,7 @@ import logging
import threading import threading
import weakref import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from typing import Any, List, Optional, Sequence, Union, cast
from uuid import UUID from uuid import UUID
import langsmith import langsmith
@ -96,7 +96,7 @@ class EvaluatorCallbackHandler(BaseTracer):
self.futures: weakref.WeakSet[Future] = weakref.WeakSet() self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished self.skip_unfinished = skip_unfinished
self.project_name = project_name self.project_name = project_name
self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}
self.lock = threading.Lock() self.lock = threading.Lock()
global _TRACERS global _TRACERS
_TRACERS.add(self) _TRACERS.add(self)
@ -152,7 +152,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _select_eval_results( def _select_eval_results(
self, self,
results: Union[EvaluationResult, EvaluationResults], results: Union[EvaluationResult, EvaluationResults],
) -> List[EvaluationResult]: ) -> list[EvaluationResult]:
if isinstance(results, EvaluationResult): if isinstance(results, EvaluationResult):
results_ = [results] results_ = [results]
elif isinstance(results, dict) and "results" in results: elif isinstance(results, dict) and "results" in results:
@ -169,10 +169,10 @@ class EvaluatorCallbackHandler(BaseTracer):
evaluator_response: Union[EvaluationResult, EvaluationResults], evaluator_response: Union[EvaluationResult, EvaluationResults],
run: Run, run: Run,
source_run_id: Optional[UUID] = None, source_run_id: Optional[UUID] = None,
) -> List[EvaluationResult]: ) -> list[EvaluationResult]:
results = self._select_eval_results(evaluator_response) results = self._select_eval_results(evaluator_response)
for res in results: for res in results:
source_info_: Dict[str, Any] = {} source_info_: dict[str, Any] = {}
if res.evaluator_info: if res.evaluator_info:
source_info_ = {**res.evaluator_info, **source_info_} source_info_ = {**res.evaluator_info, **source_info_}
run_id_ = getattr(res, "target_run_id", None) run_id_ = getattr(res, "target_run_id", None)
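`_select_eval_results` just normalizes a single result or a `{"results": [...]}` mapping into a list; loosely sketched standalone (ours, without the langsmith types):

from typing import Any, Union

def select_eval_results(results: Union[dict, Any]) -> list:
    # loosely mirrors EvaluatorCallbackHandler._select_eval_results
    if isinstance(results, dict) and "results" in results:
        return list(results["results"])
    return [results]

print(select_eval_results({"results": [1, 2]}))  # [1, 2]
print(select_eval_results(0.9))                  # [0.9]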

View File

@ -8,7 +8,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator, AsyncIterator,
Dict,
Iterator, Iterator,
List, List,
Optional, Optional,
@ -66,14 +65,14 @@ class RunInfo(TypedDict):
""" """
name: str name: str
tags: List[str] tags: list[str]
metadata: Dict[str, Any] metadata: dict[str, Any]
run_type: str run_type: str
inputs: NotRequired[Any] inputs: NotRequired[Any]
parent_run_id: Optional[UUID] parent_run_id: Optional[UUID]
def _assign_name(name: Optional[str], serialized: Optional[Dict[str, Any]]) -> str: def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str:
"""Assign a name to a run.""" """Assign a name to a run."""
if name is not None: if name is not None:
return name return name
@ -107,15 +106,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
# Map of run ID to run info. # Map of run ID to run info.
# the entry corresponding to a given run id is cleaned # the entry corresponding to a given run id is cleaned
# up when each corresponding run ends. # up when each corresponding run ends.
self.run_map: Dict[UUID, RunInfo] = {} self.run_map: dict[UUID, RunInfo] = {}
# The callback event that corresponds to the end of a parent run # The callback event that corresponds to the end of a parent run
# may be invoked BEFORE the callback event that corresponds to the end # may be invoked BEFORE the callback event that corresponds to the end
# of a child run, which results in clean up of run_map. # of a child run, which results in clean up of run_map.
# So we keep track of the mapping between children and parent run IDs # So we keep track of the mapping between children and parent run IDs
# in a separate container. This container is GCed when the tracer is GCed. # in a separate container. This container is GCed when the tracer is GCed.
self.parent_map: Dict[UUID, Optional[UUID]] = {} self.parent_map: dict[UUID, Optional[UUID]] = {}
self.is_tapped: Dict[UUID, Any] = {} self.is_tapped: dict[UUID, Any] = {}
# Filter which events will be sent over the queue. # Filter which events will be sent over the queue.
self.root_event_filter = _RootEventFilter( self.root_event_filter = _RootEventFilter(
@ -132,7 +131,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.send_stream = memory_stream.get_send_stream() self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream() self.receive_stream = memory_stream.get_receive_stream()
def _get_parent_ids(self, run_id: UUID) -> List[str]: def _get_parent_ids(self, run_id: UUID) -> list[str]:
"""Get the parent IDs of a run (non-recursively) cast to strings.""" """Get the parent IDs of a run (non-recursively) cast to strings."""
parent_ids = [] parent_ids = []
@ -269,8 +268,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self, self,
run_id: UUID, run_id: UUID,
*, *,
tags: Optional[List[str]], tags: Optional[list[str]],
metadata: Optional[Dict[str, Any]], metadata: Optional[dict[str, Any]],
parent_run_id: Optional[UUID], parent_run_id: Optional[UUID],
name_: str, name_: str,
run_type: str, run_type: str,
@ -296,13 +295,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chat_model_start( async def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -337,13 +336,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_llm_start( async def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
prompts: List[str], prompts: list[str],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -384,8 +383,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
data: Any, data: Any,
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Generate a custom astream event.""" """Generate a custom astream event."""
@ -456,7 +455,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_info = self.run_map.pop(run_id) run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"] inputs_ = run_info["inputs"]
generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]] generations: Union[list[list[GenerationChunk]], list[list[ChatGenerationChunk]]]
output: Union[dict, BaseMessage] = {} output: Union[dict, BaseMessage] = {}
if run_info["run_type"] == "chat_model": if run_info["run_type"] == "chat_model":
@ -504,13 +503,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chain_start( async def on_chain_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None, run_type: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -552,10 +551,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chain_end( async def on_chain_end(
self, self,
outputs: Dict[str, Any], outputs: dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""End a trace for a chain run.""" """End a trace for a chain run."""
@ -586,15 +585,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_tool_start( async def on_tool_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
input_str: str, input_str: str,
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Start a trace for a tool run.""" """Start a trace for a tool run."""
@ -653,13 +652,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_retriever_start( async def on_retriever_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
query: str, query: str,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
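This handler is what backs `Runnable.astream_events`; a short consumer sketch (ours):

import asyncio

from langchain_core.runnables import RunnableLambda

def add_one(x: int) -> int:
    return x + 1

async def main() -> None:
    async for event in RunnableLambda(add_one).astream_events(1, version="v2"):
        print(event["event"], event["name"])  # on_chain_start add_one, ...

asyncio.run(main())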

View File

@ -6,7 +6,7 @@ import logging
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID from uuid import UUID
from langsmith import Client from langsmith import Client
@ -91,7 +91,7 @@ class LangChainTracer(BaseTracer):
example_id: Optional[Union[UUID, str]] = None, example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None, project_name: Optional[str] = None,
client: Optional[Client] = None, client: Optional[Client] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the LangChain tracer. """Initialize the LangChain tracer.
@ -114,13 +114,13 @@ class LangChainTracer(BaseTracer):
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: Dict[str, Any], serialized: dict[str, Any],
messages: List[List[BaseMessage]], messages: list[list[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
@ -194,7 +194,7 @@ class LangChainTracer(BaseTracer):
) )
raise ValueError("Failed to get run URL.") raise ValueError("Failed to get run URL.")
def _get_tags(self, run: Run) -> List[str]: def _get_tags(self, run: Run) -> list[str]:
"""Get combined tags for a run.""" """Get combined tags for a run."""
tags = set(run.tags or []) tags = set(run.tags or [])
tags.update(self.tags or []) tags.update(self.tags or [])
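Constructing the tracer directly, for reference (ours; it reads LangSmith credentials from the environment):

from langchain_core.tracers.langchain import LangChainTracer

tracer = LangChainTracer(project_name="my-project", tags=["eval"])
# then pass it as a callback:
# chain.invoke(inputs, config={"callbacks": [tracer]})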

View File

@ -7,9 +7,7 @@ from collections import defaultdict
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Dict,
Iterator, Iterator,
List,
Literal, Literal,
Optional, Optional,
Sequence, Sequence,
@ -42,16 +40,16 @@ class LogEntry(TypedDict):
"""Name of the object being run.""" """Name of the object being run."""
type: str type: str
"""Type of the object being run, eg. prompt, chain, llm, etc.""" """Type of the object being run, eg. prompt, chain, llm, etc."""
tags: List[str] tags: list[str]
"""List of tags for the run.""" """List of tags for the run."""
metadata: Dict[str, Any] metadata: dict[str, Any]
"""Key-value pairs of metadata for the run.""" """Key-value pairs of metadata for the run."""
start_time: str start_time: str
"""ISO-8601 timestamp of when the run started.""" """ISO-8601 timestamp of when the run started."""
streamed_output_str: List[str] streamed_output_str: list[str]
"""List of LLM tokens streamed by this run, if applicable.""" """List of LLM tokens streamed by this run, if applicable."""
streamed_output: List[Any] streamed_output: list[Any]
"""List of output chunks streamed by this run, if available.""" """List of output chunks streamed by this run, if available."""
inputs: NotRequired[Optional[Any]] inputs: NotRequired[Optional[Any]]
"""Inputs to this run. Not available currently via astream_log.""" """Inputs to this run. Not available currently via astream_log."""
@ -69,7 +67,7 @@ class RunState(TypedDict):
id: str id: str
"""ID of the run.""" """ID of the run."""
streamed_output: List[Any] streamed_output: list[Any]
"""List of output chunks streamed by Runnable.stream()""" """List of output chunks streamed by Runnable.stream()"""
final_output: Optional[Any] final_output: Optional[Any]
"""Final output of the run, usually the result of aggregating (`+`) streamed_output. """Final output of the run, usually the result of aggregating (`+`) streamed_output.
@ -83,7 +81,7 @@ class RunState(TypedDict):
# Do we want tags/metadata on the root run? Client kinda knows it in most situations # Do we want tags/metadata on the root run? Client kinda knows it in most situations
# tags: List[str] # tags: List[str]
logs: Dict[str, LogEntry] logs: dict[str, LogEntry]
"""Map of run names to sub-runs. If filters were supplied, this list will """Map of run names to sub-runs. If filters were supplied, this list will
contain only the runs that matched the filters.""" contain only the runs that matched the filters."""
@ -91,14 +89,14 @@ class RunState(TypedDict):
class RunLogPatch: class RunLogPatch:
"""Patch to the run log.""" """Patch to the run log."""
ops: List[Dict[str, Any]] ops: list[dict[str, Any]]
"""List of jsonpatch operations, which describe how to create the run state """List of jsonpatch operations, which describe how to create the run state
from an empty dict. This is the minimal representation of the log, designed to from an empty dict. This is the minimal representation of the log, designed to
be serialized as JSON and sent over the wire to reconstruct the log on the other be serialized as JSON and sent over the wire to reconstruct the log on the other
side. Reconstruction of the state can be done with any jsonpatch-compliant library, side. Reconstruction of the state can be done with any jsonpatch-compliant library,
see https://jsonpatch.com for more information.""" see https://jsonpatch.com for more information."""
def __init__(self, *ops: Dict[str, Any]) -> None: def __init__(self, *ops: dict[str, Any]) -> None:
self.ops = list(ops) self.ops = list(ops)
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
@ -127,7 +125,7 @@ class RunLog(RunLogPatch):
state: RunState state: RunState
"""Current state of the log, obtained from applying all ops in sequence.""" """Current state of the log, obtained from applying all ops in sequence."""
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: def __init__(self, *ops: dict[str, Any], state: RunState) -> None:
super().__init__(*ops) super().__init__(*ops)
self.state = state self.state = state
@ -219,14 +217,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
self.lock = threading.Lock() self.lock = threading.Lock()
self.send_stream = memory_stream.get_send_stream() self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream() self.receive_stream = memory_stream.get_receive_stream()
self._key_map_by_run_id: Dict[UUID, str] = {} self._key_map_by_run_id: dict[UUID, str] = {}
self._counter_map_by_name: Dict[str, int] = defaultdict(int) self._counter_map_by_name: dict[str, int] = defaultdict(int)
self.root_id: Optional[UUID] = None self.root_id: Optional[UUID] = None
def __aiter__(self) -> AsyncIterator[RunLogPatch]: def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__() return self.receive_stream.__aiter__()
def send(self, *ops: Dict[str, Any]) -> bool: def send(self, *ops: dict[str, Any]) -> bool:
"""Send a patch to the stream, return False if the stream is closed. """Send a patch to the stream, return False if the stream is closed.
Args: Args:
@ -477,7 +475,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
def _get_standardized_inputs( def _get_standardized_inputs(
run: Run, schema_format: Literal["original", "streaming_events"] run: Run, schema_format: Literal["original", "streaming_events"]
) -> Optional[Dict[str, Any]]: ) -> Optional[dict[str, Any]]:
"""Extract standardized inputs from a run. """Extract standardized inputs from a run.
Standardizes the inputs based on the type of the runnable used. Standardizes the inputs based on the type of the runnable used.
@ -631,7 +629,7 @@ async def _astream_log_implementation(
except TypeError: except TypeError:
prev_final_output = None prev_final_output = None
final_output = chunk final_output = chunk
patches: List[Dict[str, Any]] = [] patches: list[dict[str, Any]] = []
if with_streamed_output_list: if with_streamed_output_list:
patches.append( patches.append(
{ {
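`RunLogPatch.ops` are plain jsonpatch operations; the first patch emitted by `astream_log` replaces the root with the initial run state. A tiny sketch of building and combining patches (ours; `jsonpatch` is already a dependency of `langchain_core`):

from langchain_core.tracers.log_stream import RunLogPatch

first = RunLogPatch({"op": "replace", "path": "", "value": {"streamed_output": []}})
second = RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "hello"})

combined = first + second  # applies both op lists to an empty state
print(combined.state)      # {'streamed_output': ['hello']}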

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import datetime import datetime
import warnings import warnings
from typing import Any, Dict, List, Optional, Type from typing import Any, Optional
from uuid import UUID from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2 from langsmith.schemas import RunBase as BaseRunV2
@ -18,7 +18,7 @@ from langchain_core._api import deprecated
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0") @deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
def RunTypeEnum() -> Type[RunTypeEnumDep]: def RunTypeEnum() -> type[RunTypeEnumDep]:
"""RunTypeEnum.""" """RunTypeEnum."""
warnings.warn( warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead" "RunTypeEnum is deprecated. Please directly use a string instead"
@ -35,7 +35,7 @@ class TracerSessionV1Base(BaseModelV1):
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
name: Optional[str] = None name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None extra: Optional[dict[str, Any]] = None
@deprecated("0.1.0", removal="1.0") @deprecated("0.1.0", removal="1.0")
@ -72,10 +72,10 @@ class BaseRun(BaseModelV1):
parent_uuid: Optional[str] = None parent_uuid: Optional[str] = None
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None extra: Optional[dict[str, Any]] = None
execution_order: int execution_order: int
child_execution_order: int child_execution_order: int
serialized: Dict[str, Any] serialized: dict[str, Any]
session_id: int session_id: int
error: Optional[str] = None error: Optional[str] = None
@ -84,7 +84,7 @@ class BaseRun(BaseModelV1):
class LLMRun(BaseRun): class LLMRun(BaseRun):
"""Class for LLMRun.""" """Class for LLMRun."""
prompts: List[str] prompts: list[str]
# Temporarily, remove but we will completely remove LLMRun # Temporarily, remove but we will completely remove LLMRun
# response: Optional[LLMResult] = None # response: Optional[LLMResult] = None
@ -93,11 +93,11 @@ class LLMRun(BaseRun):
class ChainRun(BaseRun): class ChainRun(BaseRun):
"""Class for ChainRun.""" """Class for ChainRun."""
inputs: Dict[str, Any] inputs: dict[str, Any]
outputs: Optional[Dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list) child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list) child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list) child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
@deprecated("0.1.0", alternative="Run", removal="1.0") @deprecated("0.1.0", alternative="Run", removal="1.0")
@ -107,9 +107,9 @@ class ToolRun(BaseRun):
tool_input: str tool_input: str
output: Optional[str] = None output: Optional[str] = None
action: str action: str
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list) child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list) child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list) child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
# Begin V2 API Schemas # Begin V2 API Schemas
@ -126,9 +126,9 @@ class Run(BaseRunV2):
dotted_order: The dotted order. dotted_order: The dotted order.
""" """
child_runs: List[Run] = FieldV1(default_factory=list) child_runs: list[Run] = FieldV1(default_factory=list)
tags: Optional[List[str]] = FieldV1(default_factory=list) tags: Optional[list[str]] = FieldV1(default_factory=list)
events: List[Dict[str, Any]] = FieldV1(default_factory=list) events: list[dict[str, Any]] = FieldV1(default_factory=list)
trace_id: Optional[UUID] = None trace_id: Optional[UUID] = None
dotted_order: Optional[str] = None dotted_order: Optional[str] = None
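With `RunTypeEnum` deprecated in favor of bare strings, downstream code can validate run types with an ordinary `set[str]`. An illustrative sketch (ours; the value list is not exhaustive):

VALID_RUN_TYPES: set[str] = {"llm", "chat_model", "chain", "tool", "retriever"}

def check_run_type(run_type: str) -> str:
    if run_type not in VALID_RUN_TYPES:
        raise ValueError(f"unexpected run type: {run_type!r}")
    return run_type

print(check_run_type("chain"))  # chain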

View File

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Optional
def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]: def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]:
"""Merge many dicts, handling specific scenarios where a key exists in both """Merge many dicts, handling specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary. value from 'right' for that key in the merged dictionary.
@ -69,7 +69,7 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
return merged return merged
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]: def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]:
"""Add many lists, handling None. """Add many lists, handling None.
Args: Args:
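These merge helpers are the accumulation primitives behind chunk addition during streaming; a quick demonstration (ours):

from langchain_core.utils._merge import merge_dicts

left = {"a": None, "b": "Hello, ", "c": {"x": 1}}
right = {"a": "A", "b": "world", "c": {"y": 2}}

# None is overwritten, strings concatenate, nested dicts merge recursively
print(merge_dicts(left, right))
# {'a': 'A', 'b': 'Hello, world', 'c': {'x': 1, 'y': 2}}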

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
def env_var_is_set(env_var: str) -> bool: def env_var_is_set(env_var: str) -> bool:
@ -22,8 +22,8 @@ def env_var_is_set(env_var: str) -> bool:
def get_from_dict_or_env( def get_from_dict_or_env(
data: Dict[str, Any], data: dict[str, Any],
key: Union[str, List[str]], key: Union[str, list[str]],
env_key: str, env_key: str,
default: Optional[str] = None, default: Optional[str] = None,
) -> str: ) -> str:
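For reference, `key` may be a single name or a list of aliases tried in order before falling back to the environment. A sketch (ours):

import os

from langchain_core.utils.env import get_from_dict_or_env

os.environ["MY_API_KEY"] = "from-env"

print(get_from_dict_or_env({"api_key": "from-dict"}, ["api_key", "my_key"], "MY_API_KEY"))  # from-dict
print(get_from_dict_or_env({}, ["api_key", "my_key"], "MY_API_KEY"))  # from-env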

View File

@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
removal="1.0", removal="1.0",
) )
def convert_pydantic_to_openai_function( def convert_pydantic_to_openai_function(
model: Type, model: type,
*, *,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
@ -106,8 +106,10 @@ def convert_pydantic_to_openai_function(
""" """
if hasattr(model, "model_json_schema"): if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2 schema = model.model_json_schema() # Pydantic 2
else: elif hasattr(model, "schema"):
schema = model.schema() # Pydantic 1 schema = model.schema() # Pydantic 1
else:
raise TypeError("Model must be a Pydantic model.")
schema = dereference_refs(schema) schema = dereference_refs(schema)
if "definitions" in schema: # pydantic 1 if "definitions" in schema: # pydantic 1
schema.pop("definitions", None) schema.pop("definitions", None)
@ -128,7 +130,7 @@ def convert_pydantic_to_openai_function(
removal="1.0", removal="1.0",
) )
def convert_pydantic_to_openai_tool( def convert_pydantic_to_openai_tool(
model: Type[BaseModel], model: type[BaseModel],
*, *,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
@ -194,8 +196,8 @@ def convert_python_function_to_openai_function(
) )
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription: def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
visited: Dict = {} visited: dict = {}
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
model = cast( model = cast(
@@ -209,11 +211,11 @@ _MAX_TYPED_DICT_RECURSION = 25
 def _convert_any_typed_dicts_to_pydantic(
-    type_: Type,
+    type_: type,
     *,
-    visited: Dict,
+    visited: dict,
     depth: int = 0,
-) -> Type:
+) -> type:
     from pydantic.v1 import Field as Field_v1
     from pydantic.v1 import create_model as create_model_v1
@@ -267,9 +269,9 @@ def _convert_any_typed_dicts_to_pydantic(
         subscriptable_origin = _py_38_safe_origin(origin)
         type_args = tuple(
             _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
-            for arg in type_args
+            for arg in type_args  # type: ignore[index]
         )
-        return subscriptable_origin[type_args]
+        return subscriptable_origin[type_args]  # type: ignore[index]
     else:
         return type_
@@ -333,10 +335,10 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
 def convert_to_openai_function(
-    function: Union[Dict[str, Any], Type, Callable, BaseTool],
+    function: Union[dict[str, Any], type, Callable, BaseTool],
     *,
     strict: Optional[bool] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """Convert a raw function/class to an OpenAI function.

     .. versionchanged:: 0.2.29
@@ -411,10 +413,10 @@ def convert_to_openai_function(
 def convert_to_openai_tool(
-    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+    tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
     *,
     strict: Optional[bool] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """Convert a raw function/class to an OpenAI tool.

     .. versionchanged:: 0.2.29
@@ -441,13 +443,13 @@ def convert_to_openai_tool(
     if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
         return tool
     oai_function = convert_to_openai_function(tool, strict=strict)
-    oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function}
+    oai_tool: dict[str, Any] = {"type": "function", "function": oai_function}
     return oai_tool


 def tool_example_to_messages(
-    input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None
-) -> List[BaseMessage]:
+    input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None
+) -> list[BaseMessage]:
     """Convert an example into a list of messages that can be fed into an LLM.

     This code is an adapter that converts a single example to a list of messages
@@ -511,7 +513,7 @@ def tool_example_to_messages(
             tool_example_to_messages(txt, [tool_call])
         )
     """
-    messages: List[BaseMessage] = [HumanMessage(content=input)]
+    messages: list[BaseMessage] = [HumanMessage(content=input)]
     openai_tool_calls = []
     for tool_call in tool_calls:
         openai_tool_calls.append(
@@ -540,10 +542,10 @@ def tool_example_to_messages(
 def _parse_google_docstring(
     docstring: Optional[str],
-    args: List[str],
+    args: list[str],
     *,
     error_on_invalid_docstring: bool = False,
-) -> Tuple[str, dict]:
+) -> tuple[str, dict]:
     """Parse the function and argument descriptions from the docstring of a function.

     Assumes the function docstring follows Google Python style guide.
@@ -590,12 +592,12 @@ def _parse_google_docstring(
     return description, arg_descriptions


-def _py_38_safe_origin(origin: Type) -> Type:
-    origin_union_type_map: Dict[Type, Any] = (
+def _py_38_safe_origin(origin: type) -> type:
+    origin_union_type_map: dict[type, Any] = (
         {types.UnionType: Union} if hasattr(types, "UnionType") else {}
     )
-    origin_map: Dict[Type, Any] = {
+    origin_map: dict[type, Any] = {
         dict: Dict,
         list: List,
         tuple: Tuple,
@@ -610,8 +612,8 @@ def _py_38_safe_origin(origin: Type) -> Type:
 def _recursive_set_additional_properties_false(
-    schema: Dict[str, Any],
-) -> Dict[str, Any]:
+    schema: dict[str, Any],
+) -> dict[str, Any]:
     if isinstance(schema, dict):
         # Check if 'required' is a key at the current level or if the schema is empty,
         # in which case additionalProperties still needs to be specified.
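The hunks above only swap `typing` generics (`Dict`, `Type`, `List`, `Tuple`) for their builtin equivalents; runtime behavior of `convert_to_openai_function` and `convert_to_openai_tool` is unchanged. A minimal usage sketch of the public entry point (`multiply` is a made-up example function, not code from this commit):

from langchain_core.utils.function_calling import convert_to_openai_tool

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

# Name, description, and JSON-schema parameters are derived from the
# function's signature and docstring.
tool = convert_to_openai_tool(multiply)
assert tool["type"] == "function"
assert tool["function"]["name"] == "multiply"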


@@ -2,7 +2,7 @@
 from __future__ import annotations

 import json
 import re
-from typing import Any, Callable, List
+from typing import Any, Callable

 from langchain_core.exceptions import OutputParserException
@@ -163,7 +163,7 @@ def _parse_json(
     return parser(json_str)


-def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
+def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
     """
     Parse a JSON string from a Markdown string and check that it
     contains the expected keys.
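Only the `expected_keys` annotation changed here. A small sketch of what the function does, assuming it is importable from `langchain_core.utils.json` as in current releases:

from langchain_core.utils.json import parse_and_check_json_markdown

# Extracts the JSON payload from a fenced ```json block and verifies keys.
text = '```json\n{"answer": 42, "source": "doc1"}\n```'
parsed = parse_and_check_json_markdown(text, expected_keys=["answer", "source"])
assert parsed == {"answer": 42, "source": "doc1"}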


@@ -1,7 +1,7 @@
 from __future__ import annotations

 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Sequence, Set
+from typing import Any, Optional, Sequence


 def _retrieve_ref(path: str, schema: dict) -> dict:
@@ -24,9 +24,9 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
 def _dereference_refs_helper(
     obj: Any,
-    full_schema: Dict[str, Any],
+    full_schema: dict[str, Any],
     skip_keys: Sequence[str],
-    processed_refs: Optional[Set[str]] = None,
+    processed_refs: Optional[set[str]] = None,
 ) -> Any:
     if processed_refs is None:
         processed_refs = set()
@@ -63,8 +63,8 @@ def _dereference_refs_helper(
 def _infer_skip_keys(
-    obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None
-) -> List[str]:
+    obj: Any, full_schema: dict, processed_refs: Optional[set[str]] = None
+) -> list[str]:
     if processed_refs is None:
         processed_refs = set()
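These private helpers back the public `dereference_refs`. A sketch of the behavior they implement, inlining a JSON Schema `$ref` (assuming the public wrapper is exported from `langchain_core.utils.json_schema`):

from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"person": {"$ref": "#/$defs/Person"}},
    "$defs": {"Person": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
resolved = dereference_refs(schema)
# The $ref under "person" is replaced by the Person definition itself.
assert resolved["properties"]["person"]["type"] == "object"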


@@ -16,7 +16,6 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
-    Tuple,
     Union,
     cast,
 )
@@ -45,7 +44,7 @@ class ChevronError(SyntaxError):
 #


-def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
+def grab_literal(template: str, l_del: str) -> tuple[str, str]:
     """Parse a literal from the template.

     Args:
@@ -124,7 +123,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
         return False


-def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
+def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]:
     """Parse a tag from a template.

     Args:
@@ -201,7 +200,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
 def tokenize(
     template: str, def_ldel: str = "{{", def_rdel: str = "}}"
-) -> Iterator[Tuple[str, str]]:
+) -> Iterator[tuple[str, str]]:
     """Tokenize a mustache template.

     Tokenizes a mustache template in a generator fashion,
@@ -427,13 +426,13 @@ def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
 #
 # The main rendering function
 #
-g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
+g_token_cache: dict[str, list[tuple[str, str]]] = {}

 EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})


 def render(
-    template: Union[str, List[Tuple[str, str]]] = "",
+    template: Union[str, list[tuple[str, str]]] = "",
     data: Mapping[str, Any] = EMPTY_DICT,
     partials_dict: Mapping[str, str] = EMPTY_DICT,
     padding: str = "",
@@ -490,7 +489,7 @@ def render(
     if isinstance(template, Sequence) and not isinstance(template, str):
         # Then we don't need to tokenize it
         # But it does need to be a generator
-        tokens: Iterator[Tuple[str, str]] = (token for token in template)
+        tokens: Iterator[tuple[str, str]] = (token for token in template)
     else:
         if template in g_token_cache:
             tokens = (token for token in g_token_cache[template])
@@ -561,7 +560,7 @@ def render(
         if callable(scope):
             # Generate template text from tags
             text = ""
-            tags: List[Tuple[str, str]] = []
+            tags: list[tuple[str, str]] = []
             for token in tokens:
                 if token == ("end", key):
                     break
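`tokenize` and `render` keep their semantics; only the `Tuple`/`List` annotations changed. Usage sketch of the renderer touched above:

from langchain_core.utils.mustache import render

# render() tokenizes the template (cached in g_token_cache) and substitutes data.
assert render("Hello, {{name}}!", {"name": "World"}) == "Hello, World!"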


@@ -11,7 +11,6 @@ from typing import (
     Any,
     Callable,
     Dict,
-    List,
     Optional,
     Type,
     TypeVar,
@@ -71,7 +70,7 @@ else:
 TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)


-def is_pydantic_v1_subclass(cls: Type) -> bool:
+def is_pydantic_v1_subclass(cls: type) -> bool:
     """Check if the installed Pydantic version is 1.x-like."""
     if PYDANTIC_MAJOR_VERSION == 1:
         return True
@@ -83,14 +82,14 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
     return False


-def is_pydantic_v2_subclass(cls: Type) -> bool:
+def is_pydantic_v2_subclass(cls: type) -> bool:
     """Check if the installed Pydantic version is 1.x-like."""
     from pydantic import BaseModel

     return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)


-def is_basemodel_subclass(cls: Type) -> bool:
+def is_basemodel_subclass(cls: type) -> bool:
     """Check if the given class is a subclass of Pydantic BaseModel.

     Check if the given class is a subclass of any of the following:
@@ -166,7 +165,7 @@ def pre_init(func: Callable) -> Any:
     @root_validator(pre=True)
     @wraps(func)
-    def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
+    def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
         """Decorator to run a function before model initialization.

         Args:
@@ -218,12 +217,12 @@ class _IgnoreUnserializable(GenerateJsonSchema):
 def _create_subset_model_v1(
     name: str,
-    model: Type[BaseModel],
+    model: type[BaseModel],
     field_names: list,
     *,
     descriptions: Optional[dict] = None,
     fn_description: Optional[str] = None,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     """Create a pydantic model with only a subset of model's fields."""
     if PYDANTIC_MAJOR_VERSION == 1:
         from pydantic import create_model
@@ -256,12 +255,12 @@ def _create_subset_model_v1(
 def _create_subset_model_v2(
     name: str,
-    model: Type[pydantic.BaseModel],
-    field_names: List[str],
+    model: type[pydantic.BaseModel],
+    field_names: list[str],
     *,
     descriptions: Optional[dict] = None,
     fn_description: Optional[str] = None,
-) -> Type[pydantic.BaseModel]:
+) -> type[pydantic.BaseModel]:
     """Create a pydantic model with a subset of the model fields."""
     from pydantic import create_model
     from pydantic.fields import FieldInfo
@@ -299,11 +298,11 @@ def _create_subset_model_v2(
 def _create_subset_model(
     name: str,
     model: TypeBaseModel,
-    field_names: List[str],
+    field_names: list[str],
     *,
     descriptions: Optional[dict] = None,
     fn_description: Optional[str] = None,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     """Create subset model using the same pydantic version as the input model."""
     if PYDANTIC_MAJOR_VERSION == 1:
         return _create_subset_model_v1(
@@ -344,25 +343,25 @@ if PYDANTIC_MAJOR_VERSION == 2:
     from pydantic.v1 import BaseModel as BaseModelV1

     @overload
-    def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
+    def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...

     @overload
-    def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
+    def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...

     @overload
-    def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
+    def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...

     @overload
-    def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
+    def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...

     def get_fields(
         model: Union[
             BaseModelV2,
             BaseModelV1,
-            Type[BaseModelV2],
-            Type[BaseModelV1],
+            type[BaseModelV2],
+            type[BaseModelV1],
         ],
-    ) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
+    ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
         """Get the field names of a Pydantic model."""
         if hasattr(model, "model_fields"):
             return model.model_fields  # type: ignore
@@ -375,8 +374,8 @@ elif PYDANTIC_MAJOR_VERSION == 1:
     from pydantic import BaseModel as BaseModelV1_

     def get_fields(  # type: ignore[no-redef]
-        model: Union[Type[BaseModelV1_], BaseModelV1_],
-    ) -> Dict[str, FieldInfoV1]:
+        model: Union[type[BaseModelV1_], BaseModelV1_],
+    ) -> dict[str, FieldInfoV1]:
         """Get the field names of a Pydantic model."""
         return model.__fields__  # type: ignore
 else:
@@ -394,14 +393,14 @@ def _create_root_model(
     type_: Any,
     module_name: Optional[str] = None,
     default_: object = NO_DEFAULT,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     """Create a base class."""

     def schema(
-        cls: Type[BaseModel],
+        cls: type[BaseModel],
         by_alias: bool = True,
         ref_template: str = DEFAULT_REF_TEMPLATE,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         # Complains about schema not being defined in superclass
         schema_ = super(cls, cls).schema(  # type: ignore[misc]
             by_alias=by_alias, ref_template=ref_template
@@ -410,12 +409,12 @@ def _create_root_model(
         return schema_

     def model_json_schema(
-        cls: Type[BaseModel],
+        cls: type[BaseModel],
         by_alias: bool = True,
         ref_template: str = DEFAULT_REF_TEMPLATE,
         schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
         mode: JsonSchemaMode = "validation",
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         # Complains about model_json_schema not being defined in superclass
         schema_ = super(cls, cls).model_json_schema(  # type: ignore[misc]
             by_alias=by_alias,
@@ -452,7 +451,7 @@ def _create_root_model_cached(
     *,
     module_name: Optional[str] = None,
     default_: object = NO_DEFAULT,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     return _create_root_model(
         model_name, type_, default_=default_, module_name=module_name
     )
@@ -462,7 +461,7 @@ def _create_root_model_cached(
 def _create_model_cached(
     __model_name: str,
     **field_definitions: Any,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     return _create_model_base(
         __model_name,
         __config__=_SchemaConfig,
@@ -474,7 +473,7 @@ def create_model(
     __model_name: str,
     __module_name: Optional[str] = None,
     **field_definitions: Any,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     """Create a pydantic model with the given field definitions.

     Please use create_model_v2 instead of this function.
@@ -513,7 +512,7 @@ def create_model(
 _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}


-def _remap_field_definitions(field_definitions: Dict[str, Any]) -> Dict[str, Any]:
+def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
     """This remaps fields to avoid colliding with internal pydantic fields."""
     from pydantic import Field
     from pydantic.fields import FieldInfo
@@ -547,9 +546,9 @@ def create_model_v2(
     model_name: str,
     *,
     module_name: Optional[str] = None,
-    field_definitions: Optional[Dict[str, Any]] = None,
+    field_definitions: Optional[dict[str, Any]] = None,
     root: Optional[Any] = None,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
     """Create a pydantic model with the given field definitions.

     Attention:
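A small sketch tying the helpers above together, assuming `create_model_v2`, `get_fields`, and `is_basemodel_subclass` are exported from `langchain_core.utils.pydantic` and Pydantic v2 is installed (which this PR now requires):

from langchain_core.utils.pydantic import (
    create_model_v2,
    get_fields,
    is_basemodel_subclass,
)

# Build a model dynamically; field_definitions maps name -> (type, default).
Point = create_model_v2("Point", field_definitions={"x": (int, ...), "y": (int, 0)})
assert is_basemodel_subclass(Point)
assert set(get_fields(Point)) == {"x", "y"}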


@@ -32,13 +32,9 @@ from typing import (
     Callable,
     ClassVar,
     Collection,
-    Dict,
     Iterable,
-    List,
     Optional,
     Sequence,
-    Tuple,
-    Type,
     TypeVar,
 )
@@ -66,13 +62,13 @@ class VectorStore(ABC):
     def add_texts(
         self,
         texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         # One of the kwargs should be `ids` which is a list of ids
         # associated with the texts.
         # This is not yet enforced in the type signature for backwards compatibility
         # with existing implementations.
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         """Run more texts through the embeddings and add to the vectorstore.

         Args:
@@ -124,7 +120,7 @@ class VectorStore(ABC):
         )
         return None

-    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
         """Delete by vector ID or other criteria.

         Args:
@@ -138,7 +134,7 @@ class VectorStore(ABC):
         raise NotImplementedError("delete method must be implemented by subclass.")

-    def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Get documents by their IDs.

         The returned documents are expected to have the ID field set to the ID of the
@@ -167,7 +163,7 @@ class VectorStore(ABC):
         )

     # Implementations should override this method to provide an async native version.
-    async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+    async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Async get documents by their IDs.

         The returned documents are expected to have the ID field set to the ID of the
@@ -194,7 +190,7 @@ class VectorStore(ABC):
         return await run_in_executor(None, self.get_by_ids, ids)

     async def adelete(
-        self, ids: Optional[List[str]] = None, **kwargs: Any
+        self, ids: Optional[list[str]] = None, **kwargs: Any
     ) -> Optional[bool]:
         """Async delete by vector ID or other criteria.
@@ -211,9 +207,9 @@ class VectorStore(ABC):
     async def aadd_texts(
         self,
         texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         """Async run more texts through the embeddings and add to the vectorstore.

         Args:
@@ -254,7 +250,7 @@ class VectorStore(ABC):
             return await self.aadd_documents(docs, **kwargs)
         return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)

-    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
+    def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
         """Add or update documents in the vectorstore.

         Args:
@@ -287,8 +283,8 @@ class VectorStore(ABC):
         )

     async def aadd_documents(
-        self, documents: List[Document], **kwargs: Any
-    ) -> List[str]:
+        self, documents: list[Document], **kwargs: Any
+    ) -> list[str]:
         """Async run more documents through the embeddings and add to
         the vectorstore.
@@ -318,7 +314,7 @@ class VectorStore(ABC):
         return await run_in_executor(None, self.add_documents, documents, **kwargs)

-    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
+    def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
         """Return docs most similar to query using a specified search type.

         Args:
@@ -352,7 +348,7 @@ class VectorStore(ABC):
     async def asearch(
         self, query: str, search_type: str, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Async return docs most similar to query using a specified search type.

         Args:
@@ -386,7 +382,7 @@ class VectorStore(ABC):
     @abstractmethod
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Return docs most similar to query.

         Args:
@@ -442,7 +438,7 @@ class VectorStore(ABC):
     def similarity_search_with_score(
         self, *args: Any, **kwargs: Any
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Run similarity search with distance.

         Args:
@@ -456,7 +452,7 @@ class VectorStore(ABC):
     async def asimilarity_search_with_score(
         self, *args: Any, **kwargs: Any
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Async run similarity search with distance.

         Args:
@@ -479,7 +475,7 @@ class VectorStore(ABC):
         query: str,
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """
         Default similarity search with relevance scores. Modify if necessary
         in subclass.
@@ -506,7 +502,7 @@ class VectorStore(ABC):
         query: str,
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """
         Default similarity search with relevance scores. Modify if necessary
         in subclass.
@@ -533,7 +529,7 @@ class VectorStore(ABC):
         query: str,
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Return docs and relevance scores in the range [0, 1].

         0 is dissimilar, 1 is most similar.
@@ -581,7 +577,7 @@ class VectorStore(ABC):
         query: str,
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Async return docs and relevance scores in the range [0, 1].

         0 is dissimilar, 1 is most similar.
@@ -626,7 +622,7 @@ class VectorStore(ABC):
     async def asimilarity_search(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Async return docs most similar to query.

         Args:
@@ -644,8 +640,8 @@ class VectorStore(ABC):
         return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)

     def similarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+        self, embedding: list[float], k: int = 4, **kwargs: Any
+    ) -> list[Document]:
         """Return docs most similar to embedding vector.

         Args:
@@ -659,8 +655,8 @@ class VectorStore(ABC):
         raise NotImplementedError

     async def asimilarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+        self, embedding: list[float], k: int = 4, **kwargs: Any
+    ) -> list[Document]:
         """Async return docs most similar to embedding vector.

         Args:
@@ -686,7 +682,7 @@ class VectorStore(ABC):
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance.

         Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -715,7 +711,7 @@ class VectorStore(ABC):
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Async return docs selected using the maximal marginal relevance.

         Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -750,12 +746,12 @@ class VectorStore(ABC):
     def max_marginal_relevance_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance.

         Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -779,12 +775,12 @@ class VectorStore(ABC):
     async def amax_marginal_relevance_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Async return docs selected using the maximal marginal relevance.

         Maximal marginal relevance optimizes for similarity to query AND diversity
@@ -816,8 +812,8 @@ class VectorStore(ABC):
     @classmethod
     def from_documents(
-        cls: Type[VST],
-        documents: List[Document],
+        cls: type[VST],
+        documents: list[Document],
         embedding: Embeddings,
         **kwargs: Any,
     ) -> VST:
@@ -837,8 +833,8 @@ class VectorStore(ABC):
     @classmethod
     async def afrom_documents(
-        cls: Type[VST],
-        documents: List[Document],
+        cls: type[VST],
+        documents: list[Document],
         embedding: Embeddings,
         **kwargs: Any,
     ) -> VST:
@@ -859,10 +855,10 @@ class VectorStore(ABC):
     @classmethod
     @abstractmethod
     def from_texts(
-        cls: Type[VST],
-        texts: List[str],
+        cls: type[VST],
+        texts: list[str],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
     ) -> VST:
         """Return VectorStore initialized from texts and embeddings.
@@ -880,10 +876,10 @@ class VectorStore(ABC):
     @classmethod
     async def afrom_texts(
-        cls: Type[VST],
-        texts: List[str],
+        cls: type[VST],
+        texts: list[str],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
     ) -> VST:
         """Async return VectorStore initialized from texts and embeddings.
@@ -902,7 +898,7 @@ class VectorStore(ABC):
             None, cls.from_texts, texts, embedding, metadatas, **kwargs
         )

-    def _get_retriever_tags(self) -> List[str]:
+    def _get_retriever_tags(self) -> list[str]:
         """Get tags for retriever."""
         tags = [self.__class__.__name__]
         if self.embeddings:
@@ -991,7 +987,7 @@ class VectorStoreRetriever(BaseRetriever):
     @model_validator(mode="before")
     @classmethod
-    def validate_search_type(cls, values: Dict) -> Any:
+    def validate_search_type(cls, values: dict) -> Any:
         """Validate search type.

         Args:
@@ -1040,7 +1036,7 @@ class VectorStoreRetriever(BaseRetriever):
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
+    ) -> list[Document]:
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
         elif self.search_type == "similarity_score_threshold":
@@ -1060,7 +1056,7 @@ class VectorStoreRetriever(BaseRetriever):
     async def _aget_relevant_documents(
         self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
-    ) -> List[Document]:
+    ) -> list[Document]:
         if self.search_type == "similarity":
             docs = await self.vectorstore.asimilarity_search(
                 query, **self.search_kwargs
@@ -1080,7 +1076,7 @@ class VectorStoreRetriever(BaseRetriever):
             raise ValueError(f"search_type of {self.search_type} not allowed.")
         return docs

-    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
+    def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
         """Add documents to the vectorstore.

         Args:
@@ -1093,8 +1089,8 @@ class VectorStoreRetriever(BaseRetriever):
         return self.vectorstore.add_documents(documents, **kwargs)

     async def aadd_documents(
-        self, documents: List[Document], **kwargs: Any
-    ) -> List[str]:
+        self, documents: list[Document], **kwargs: Any
+    ) -> list[str]:
         """Async add documents to the vectorstore.

         Args:
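Because Python does not enforce annotations at runtime, downstream `VectorStore` subclasses that still write `List[str]` remain compatible with the new `list[str]` base signatures. A minimal sketch of such a subclass (`TrivialStore` is hypothetical, and its search logic is deliberately fake):

from typing import Any, List, Optional  # old-style annotations, still valid

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

class TrivialStore(VectorStore):
    """Hypothetical third-party store using pre-PEP 585 annotations."""

    def __init__(self) -> None:
        self.docs: List[Document] = []

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "TrivialStore":
        store = cls()
        store.docs = [Document(page_content=t) for t in texts]
        return store

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        # Not a real similarity metric: just returns the first k documents.
        return self.docs[:k]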


@@ -7,12 +7,9 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Dict,
     Iterator,
-    List,
     Optional,
     Sequence,
-    Tuple,
 )

 from langchain_core._api import deprecated
@@ -153,7 +150,7 @@ class InMemoryVectorStore(VectorStore):
         """
         # TODO: would be nice to change to
         # Dict[str, Document] at some point (will be a breaking change)
-        self.store: Dict[str, Dict[str, Any]] = {}
+        self.store: dict[str, dict[str, Any]] = {}
         self.embedding = embedding

     @property
@@ -170,10 +167,10 @@ class InMemoryVectorStore(VectorStore):
     def add_documents(
         self,
-        documents: List[Document],
-        ids: Optional[List[str]] = None,
+        documents: list[Document],
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         """Add documents to the store."""
         texts = [doc.page_content for doc in documents]
         vectors = self.embedding.embed_documents(texts)
@@ -204,8 +201,8 @@ class InMemoryVectorStore(VectorStore):
         return ids_

     async def aadd_documents(
-        self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any
-    ) -> List[str]:
+        self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any
+    ) -> list[str]:
         """Add documents to the store."""
         texts = [doc.page_content for doc in documents]
         vectors = await self.embedding.aembed_documents(texts)
@@ -219,7 +216,7 @@ class InMemoryVectorStore(VectorStore):
         id_iterator: Iterator[Optional[str]] = (
             iter(ids) if ids else iter(doc.id for doc in documents)
         )
-        ids_: List[str] = []
+        ids_: list[str] = []

         for doc, vector in zip(documents, vectors):
             doc_id = next(id_iterator)
@@ -234,7 +231,7 @@ class InMemoryVectorStore(VectorStore):
         return ids_

-    def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Get documents by their ids.

         Args:
@@ -313,7 +310,7 @@ class InMemoryVectorStore(VectorStore):
             "failed": [],
         }

-    async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+    async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Async get documents by their ids.

         Args:
@@ -326,11 +323,11 @@ class InMemoryVectorStore(VectorStore):
     def _similarity_search_with_score_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         filter: Optional[Callable[[Document], bool]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float, List[float]]]:
+    ) -> list[tuple[Document, float, list[float]]]:
         result = []
         for doc in self.store.values():
             vector = doc["vector"]
@@ -351,11 +348,11 @@ class InMemoryVectorStore(VectorStore):
     def similarity_search_with_score_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         filter: Optional[Callable[[Document], bool]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         return [
             (doc, similarity)
             for doc, similarity, _ in self._similarity_search_with_score_by_vector(
@@ -368,7 +365,7 @@ class InMemoryVectorStore(VectorStore):
         query: str,
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         embedding = self.embedding.embed_query(query)
         docs = self.similarity_search_with_score_by_vector(
             embedding,
@@ -379,7 +376,7 @@ class InMemoryVectorStore(VectorStore):
     async def asimilarity_search_with_score(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         embedding = await self.embedding.aembed_query(query)
         docs = self.similarity_search_with_score_by_vector(
             embedding,
@@ -390,10 +387,10 @@ class InMemoryVectorStore(VectorStore):
     def similarity_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         docs_and_scores = self.similarity_search_with_score_by_vector(
             embedding,
             k,
@@ -402,18 +399,18 @@ class InMemoryVectorStore(VectorStore):
         return [doc for doc, _ in docs_and_scores]

     async def asimilarity_search_by_vector(
-        self, embedding: List[float], k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+        self, embedding: list[float], k: int = 4, **kwargs: Any
+    ) -> list[Document]:
         return self.similarity_search_by_vector(embedding, k, **kwargs)

     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]

     async def asimilarity_search(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         return [
             doc
             for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
@@ -421,12 +418,12 @@ class InMemoryVectorStore(VectorStore):
     def max_marginal_relevance_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         prefetch_hits = self._similarity_search_with_score_by_vector(
             embedding=embedding,
             k=fetch_k,
@@ -456,7 +453,7 @@ class InMemoryVectorStore(VectorStore):
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         embedding_vector = self.embedding.embed_query(query)
         return self.max_marginal_relevance_search_by_vector(
             embedding_vector,
@@ -473,7 +470,7 @@ class InMemoryVectorStore(VectorStore):
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         embedding_vector = await self.embedding.aembed_query(query)
         return self.max_marginal_relevance_search_by_vector(
             embedding_vector,
@@ -486,9 +483,9 @@ class InMemoryVectorStore(VectorStore):
     @classmethod
     def from_texts(
         cls,
-        texts: List[str],
+        texts: list[str],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
     ) -> InMemoryVectorStore:
         store = cls(
@@ -500,9 +497,9 @@ class InMemoryVectorStore(VectorStore):
     @classmethod
     async def afrom_texts(
         cls,
-        texts: List[str],
+        texts: list[str],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
     ) -> InMemoryVectorStore:
         store = cls(
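End-to-end sketch of the store above; `FakeEmbeddings` (the test embedding model shipped in `langchain_core.embeddings`) stands in for a real embedder:

from langchain_core.embeddings import FakeEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore.from_texts(
    ["apples are red", "the sky is blue"],
    embedding=FakeEmbeddings(size=8),
)
# Embeds the query, scores every stored vector, returns the top k documents.
docs = store.similarity_search("fruit", k=1)
assert len(docs) == 1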


@@ -76,7 +76,7 @@ def maximal_marginal_relevance(
     embedding_list: list,
     lambda_mult: float = 0.5,
     k: int = 4,
-) -> List[int]:
+) -> list[int]:
     """Calculate maximal marginal relevance.

     Args:
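For reference, the selection rule MMR implements: repeatedly pick the candidate maximizing lambda_mult * sim(query, d) - (1 - lambda_mult) * max sim(d, already_chosen). A self-contained sketch of that rule (not the library implementation):

import numpy as np

def mmr_indices(
    query: np.ndarray,
    candidates: np.ndarray,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> list[int]:
    def cos(a: np.ndarray, b: np.ndarray) -> float:
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

    chosen: list[int] = []
    while len(chosen) < min(k, len(candidates)):
        best_i, best_score = -1, float("-inf")
        for i in range(len(candidates)):
            if i in chosen:
                continue
            relevance = cos(query, candidates[i])
            # Penalize similarity to anything already selected.
            redundancy = max(
                (cos(candidates[i], candidates[j]) for j in chosen), default=0.0
            )
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_i, best_score = i, score
        chosen.append(best_i)
    return chosen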

libs/core/poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

 [[package]]
 name = "annotated-types"
@@ -1186,7 +1186,7 @@ develop = true
 [package.dependencies]
 httpx = "^0.27.0"
-langchain-core = ">=0.3.0.dev1"
+langchain-core = "^0.3.0"
 pytest = ">=7,<9"
 syrupy = "^4"
@@ -1196,7 +1196,7 @@ url = "../standard-tests"
 [[package]]
 name = "langchain-text-splitters"
-version = "0.3.0.dev1"
+version = "0.3.0"
 description = "LangChain text splitting utilities"
 optional = false
 python-versions = ">=3.9,<4.0"
@@ -1204,7 +1204,7 @@ files = []
 develop = true

 [package.dependencies]
-langchain-core = "^0.3.0.dev1"
+langchain-core = "^0.3.0"

 [package.source]
 type = "directory"


@@ -41,8 +41,8 @@ python = ">=3.12.4"
 [tool.poetry.extras]

 [tool.ruff.lint]
-select = [ "B", "E", "F", "I", "T201", "UP",]
-ignore = [ "UP006", "UP007",]
+select = ["B", "E", "F", "I", "T201", "UP"]
+ignore = ["UP007"]

 [tool.coverage.run]
 omit = [ "tests/*",]
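This config hunk is what actually turns the rule on: `UP006` is dropped from the ignore list while the `UP` (pyupgrade) group stays selected, so the PEP 585 check now runs; `UP007`, the `Union[X, Y]` to `X | Y` rewrite, remains ignored. What UP006 flags, illustrated on a throwaway function:

from typing import Dict, List  # importing these for annotations is now redundant

def old_style(tags: List[str]) -> Dict[str, int]:  # UP006: use `list` / `dict`
    return {tag: len(tag) for tag in tags}

def new_style(tags: list[str]) -> dict[str, int]:  # UP006-clean since Python 3.9
    return {tag: len(tag) for tag in tags}

assert old_style(["a", "bb"]) == new_style(["a", "bb"]) == {"a": 1, "bb": 2}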


@@ -3,7 +3,7 @@
 from __future__ import annotations

 from datetime import datetime, timezone
-from typing import Any, List
+from typing import Any
 from uuid import uuid4

 import pytest
@@ -26,7 +26,7 @@ class FakeAsyncTracer(AsyncBaseTracer):
     def __init__(self) -> None:
         """Initialize the tracer."""
         super().__init__()
-        self.runs: List[Run] = []
+        self.runs: list[Run] = []

     async def _persist_run(self, run: Run) -> None:
         self.runs.append(run)


@@ -3,7 +3,7 @@
 from __future__ import annotations

 from datetime import datetime, timezone
-from typing import Any, List
+from typing import Any
 from unittest.mock import MagicMock
 from uuid import uuid4
@@ -30,7 +30,7 @@ class FakeTracer(BaseTracer):
     def __init__(self) -> None:
         """Initialize the tracer."""
         super().__init__()
-        self.runs: List[Run] = []
+        self.runs: list[Run] = []

     def _persist_run(self, run: Run) -> None:
         """Persist a run."""


@@ -7,7 +7,7 @@ the relevant methods.
 from __future__ import annotations

 import uuid
-from typing import Any, Dict, Iterable, List, Optional, Sequence
+from typing import Any, Iterable, Optional, Sequence

 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
@@ -18,19 +18,19 @@ class CustomAddTextsVectorstore(VectorStore):
     """A vectorstore that only implements add texts."""

     def __init__(self) -> None:
-        self.store: Dict[str, Document] = {}
+        self.store: dict[str, Document] = {}

     def add_texts(
         self,
         texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         # One of the kwargs should be `ids` which is a list of ids
         # associated with the texts.
         # This is not yet enforced in the type signature for backwards compatibility
         # with existing implementations.
-        ids: Optional[List[str]] = None,
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         if not isinstance(texts, list):
             texts = list(texts)
         ids_iter = iter(ids or [])
@@ -46,14 +46,14 @@ class CustomAddTextsVectorstore(VectorStore):
             ids_.append(id_)
         return ids_

-    def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         return [self.store[id] for id in ids if id in self.store]

     def from_texts(  # type: ignore
         cls,
-        texts: List[str],
+        texts: list[str],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[list[dict]] = None,
         **kwargs: Any,
     ) -> CustomAddTextsVectorstore:
         vectorstore = CustomAddTextsVectorstore()
@@ -62,7 +62,7 @@ class CustomAddTextsVectorstore(VectorStore):
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         raise NotImplementedError()
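Usage sketch for the test helper above, under the assumption (consistent with the elided body of add_texts) that it stores one Document per text keyed by its id and returns the ids:

store = CustomAddTextsVectorstore()
ids = store.add_texts(["alpha", "beta"], ids=["id-1", "id-2"])
assert ids == ["id-1", "id-2"]
assert store.get_by_ids(["id-1"])[0].page_content == "alpha"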