core[patch]: Add ruff rule UP006 (use PEP 585 annotations) (#26574)

* Added rule `UP006` now that Pydantic is v2+
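
For reference, UP006 flags `typing` generics that PEP 585 (Python 3.9+) made redundant and rewrites them to the builtins. A minimal before/after sketch (illustrative only; the function and its names are hypothetical, not code from this diff):

```python
from typing import Dict, List, Optional

# Before: typing aliases, which UP006 flags on Python 3.9+ codebases.
def collect_tags(metadata: Optional[Dict[str, str]] = None) -> List[str]:
    return sorted(metadata or {})

# After `ruff check --fix`: builtin generics per PEP 585. Optional is
# deliberately left alone -- rewriting it to `X | None` is a separate
# rule (UP007), which this commit does not enable.
def collect_tags_fixed(metadata: Optional[dict[str, str]] = None) -> list[str]:
    return sorted(metadata or {})
```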

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Christophe Bornet 2024-09-17 23:22:50 +02:00 committed by GitHub
parent 2ef4c9466f
commit 3a99467ccb
70 changed files with 1222 additions and 1286 deletions
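
The rewrites below are mechanical and behavior-preserving: on Python 3.9+, which PEP 585 requires, the builtin generics expose the same origin and arguments as their `typing` aliases, so introspection-based tooling such as Pydantic v2 accepts either spelling. A quick check of that equivalence (illustrative, not part of the commit):

```python
from typing import List, get_args, get_origin

# Both spellings resolve to the same runtime shape, so nothing that
# inspects these annotations (e.g. Pydantic v2) changes behavior.
assert get_origin(List[str]) is list
assert get_origin(list[str]) is list
assert get_args(List[str]) == get_args(list[str]) == (str,)
```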

View File

@@ -25,7 +25,7 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
from __future__ import annotations
import json
from typing import Any, List, Literal, Sequence, Union
from typing import Any, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
@@ -71,7 +71,7 @@ class AgentAction(Serializable):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "agent"]."""
return ["langchain", "schema", "agent"]
@@ -145,7 +145,7 @@ class AgentFinish(Serializable):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "agent"]

View File

@@ -23,7 +23,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence
from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor
@@ -157,7 +157,7 @@ class InMemoryCache(BaseCache):
Raises:
ValueError: If maxsize is less than or equal to 0.
"""
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {}
if maxsize is not None and maxsize <= 0:
raise ValueError("maxsize must be greater than 0")
self._maxsize = maxsize

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
from uuid import UUID
from tenacity import RetryCallState
@@ -118,7 +118,7 @@ class ChainManagerMixin:
def on_chain_end(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -222,13 +222,13 @@ class CallbackManagerMixin:
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running.
@@ -249,13 +249,13 @@ class CallbackManagerMixin:
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running.
@@ -280,13 +280,13 @@ class CallbackManagerMixin:
def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when the Retriever starts running.
@@ -303,13 +303,13 @@ class CallbackManagerMixin:
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chain starts running.
@@ -326,14 +326,14 @@ class CallbackManagerMixin:
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when the tool starts running.
@@ -393,8 +393,8 @@ class RunManagerMixin:
data: Any,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Override to define a handler for a custom event.
@@ -470,13 +470,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
@@ -497,13 +497,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running.
@@ -533,7 +533,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled.
@@ -554,7 +554,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running.
@@ -573,7 +573,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors.
@@ -590,13 +590,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when a chain starts running.
@@ -613,11 +613,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chain_end(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when a chain ends running.
@@ -636,7 +636,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors.
@@ -651,14 +651,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when the tool starts running.
@@ -680,7 +680,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when the tool ends running.
@@ -699,7 +699,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors.
@@ -718,7 +718,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on an arbitrary text.
@@ -754,7 +754,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent action.
@@ -773,7 +773,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on the agent end.
@@ -788,13 +788,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on the retriever start.
@@ -815,7 +815,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on the retriever end.
@@ -833,7 +833,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error.
@@ -852,8 +852,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
data: Any,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Override to define a handler for a custom event.
@@ -880,14 +880,14 @@ class BaseCallbackManager(CallbackManagerMixin):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
handlers: list[BaseCallbackHandler],
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
inheritable_tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
) -> None:
"""Initialize callback manager.
@@ -901,8 +901,8 @@ class BaseCallbackManager(CallbackManagerMixin):
Default is None.
metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
"""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
self.handlers: list[BaseCallbackHandler] = handlers
self.inheritable_handlers: list[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
@@ -1002,7 +1002,7 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_handlers.remove(handler)
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
self, handlers: list[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager.
@@ -1024,7 +1024,7 @@ class BaseCallbackManager(CallbackManagerMixin):
"""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
def add_tags(self, tags: list[str], inherit: bool = True) -> None:
"""Add tags to the callback manager.
Args:
@@ -1038,7 +1038,7 @@ class BaseCallbackManager(CallbackManagerMixin):
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
def remove_tags(self, tags: list[str]) -> None:
"""Remove tags from the callback manager.
Args:
@@ -1048,7 +1048,7 @@ class BaseCallbackManager(CallbackManagerMixin):
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
"""Add metadata to the callback manager.
Args:
@@ -1059,7 +1059,7 @@ class BaseCallbackManager(CallbackManagerMixin):
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
def remove_metadata(self, keys: list[str]) -> None:
"""Remove metadata from the callback manager.
Args:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, Optional, TextIO, cast
from typing import Any, Optional, TextIO, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
@@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler):
self.file.close()
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain.
@@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler):
file=self.file,
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.
Args:

View File

@@ -14,9 +14,7 @@ from typing import (
AsyncGenerator,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
Type,
@@ -64,12 +62,12 @@ def trace_as_chain_group(
group_name: str,
callback_manager: Optional[CallbackManager] = None,
*,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> Generator[CallbackManagerForChainGroup, None, None]:
"""Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if
@@ -144,12 +142,12 @@ async def atrace_as_chain_group(
group_name: str,
callback_manager: Optional[AsyncCallbackManager] = None,
*,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
"""Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if
@@ -240,7 +238,7 @@ def shielded(func: Func) -> Func:
def handle_event(
handlers: List[BaseCallbackHandler],
handlers: list[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
@@ -258,10 +256,10 @@ def handle_event(
*args: The arguments to pass to the event handler.
**kwargs: The keyword arguments to pass to the event handler
"""
coros: List[Coroutine[Any, Any, Any]] = []
coros: list[Coroutine[Any, Any, Any]] = []
try:
message_strings: Optional[List[str]] = None
message_strings: Optional[list[str]] = None
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
@@ -318,7 +316,7 @@ def handle_event(
_run_coros(coros)
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
# Run the coroutines in a new event loop, taking care to
@@ -399,7 +397,7 @@ async def _ahandle_event_for_handler(
async def ahandle_event(
handlers: List[BaseCallbackHandler],
handlers: list[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
@@ -446,13 +444,13 @@ class BaseRunManager(RunManagerMixin):
self,
*,
run_id: UUID,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler],
handlers: list[BaseCallbackHandler],
inheritable_handlers: list[BaseCallbackHandler],
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
inheritable_tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
) -> None:
"""Initialize the run manager.
@@ -481,7 +479,7 @@ class BaseRunManager(RunManagerMixin):
self.inheritable_metadata = inheritable_metadata or {}
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
def get_noop_manager(cls: type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations.
Returns:
@@ -824,7 +822,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when chain ends running.
Args:
@@ -929,7 +927,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
@shielded
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when a chain ends running.
@@ -1248,11 +1246,11 @@ class CallbackManager(BaseCallbackManager):
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
) -> list[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
Args:
@@ -1299,11 +1297,11 @@ class CallbackManager(BaseCallbackManager):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[CallbackManagerForLLMRun]:
) -> list[CallbackManagerForLLMRun]:
"""Run when LLM starts running.
Args:
@@ -1354,8 +1352,8 @@ class CallbackManager(BaseCallbackManager):
def on_chain_start(
self,
serialized: Optional[Dict[str, Any]],
inputs: Union[Dict[str, Any], Any],
serialized: Optional[dict[str, Any]],
inputs: Union[dict[str, Any], Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForChainRun:
@@ -1398,11 +1396,11 @@ class CallbackManager(BaseCallbackManager):
def on_tool_start(
self,
serialized: Optional[Dict[str, Any]],
serialized: Optional[dict[str, Any]],
input_str: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> CallbackManagerForToolRun:
"""Run when tool starts running.
@@ -1453,7 +1451,7 @@ class CallbackManager(BaseCallbackManager):
def on_retriever_start(
self,
serialized: Optional[Dict[str, Any]],
serialized: Optional[dict[str, Any]],
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
@@ -1541,10 +1539,10 @@ class CallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
) -> CallbackManager:
"""Configure the callback manager.
@@ -1583,8 +1581,8 @@ class CallbackManagerForChainGroup(CallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
handlers: list[BaseCallbackHandler],
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
parent_run_manager: CallbackManagerForChainRun,
@@ -1681,7 +1679,7 @@ class CallbackManagerForChainGroup(CallbackManager):
manager.add_handler(handler, inherit=True)
return manager
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.
Args:
@@ -1716,11 +1714,11 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
) -> list[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
Args:
@@ -1779,11 +1777,11 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[AsyncCallbackManagerForLLMRun]:
) -> list[AsyncCallbackManagerForLLMRun]:
"""Async run when LLM starts running.
Args:
@@ -1840,8 +1838,8 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chain_start(
self,
serialized: Optional[Dict[str, Any]],
inputs: Union[Dict[str, Any], Any],
serialized: Optional[dict[str, Any]],
inputs: Union[dict[str, Any], Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
@@ -1886,7 +1884,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_tool_start(
self,
serialized: Optional[Dict[str, Any]],
serialized: Optional[dict[str, Any]],
input_str: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
@@ -1975,7 +1973,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_retriever_start(
self,
serialized: Optional[Dict[str, Any]],
serialized: Optional[dict[str, Any]],
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
@@ -2027,10 +2025,10 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
) -> AsyncCallbackManager:
"""Configure the async callback manager.
@@ -2069,8 +2067,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
handlers: list[BaseCallbackHandler],
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
parent_run_manager: AsyncCallbackManagerForChainRun,
@@ -2169,7 +2167,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
return manager
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when traced chain group ends.
@@ -2202,14 +2200,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
def _configure(
callback_manager_cls: Type[T],
callback_manager_cls: type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
) -> T:
"""Configure the callback manager.

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.utils import print_text
@@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
self.color = color
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain.
@@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.
Args:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any
from langchain_core.callbacks.base import BaseCallbackHandler
@@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running.
@@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
@@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Run when a chain starts running.
@@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
**kwargs (Any): Additional keyword arguments.
"""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Run when a chain ends running.
Args:
@@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
self, serialized: dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when the tool starts running.

View File

@@ -18,7 +18,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Sequence, Union
from typing import Sequence, Union
from pydantic import BaseModel, Field
@@ -87,7 +87,7 @@ class BaseChatMessageHistory(ABC):
f.write("[]")
"""
messages: List[BaseMessage]
messages: list[BaseMessage]
"""A property or attribute that returns a list of messages.
In general, getting the messages may involve IO to the underlying
@@ -95,7 +95,7 @@ class BaseChatMessageHistory(ABC):
latency.
"""
async def aget_messages(self) -> List[BaseMessage]:
async def aget_messages(self) -> list[BaseMessage]:
"""Async version of getting messages.
Can over-ride this method to provide an efficient async implementation.
@@ -204,10 +204,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
Stores messages in a memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
messages: list[BaseMessage] = Field(default_factory=list)
"""A list of messages stored in memory."""
async def aget_messages(self) -> List[BaseMessage]:
async def aget_messages(self) -> list[BaseMessage]:
"""Async version of getting messages.
Can over-ride this method to provide an efficient async implementation.

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor
@@ -25,17 +25,17 @@ class BaseLoader(ABC): # noqa: B024
# Sub-classes should not implement this method directly. Instead, they
# should implement the lazy load method.
def load(self) -> List[Document]:
def load(self) -> list[Document]:
"""Load data into Document objects."""
return list(self.lazy_load())
async def aload(self) -> List[Document]:
async def aload(self) -> list[Document]:
"""Load data into Document objects."""
return [document async for document in self.alazy_load()]
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
) -> list[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.
Do not override this method. It should be considered to be deprecated!
@@ -108,7 +108,7 @@ class BaseBlobParser(ABC):
Generator of documents
"""
def parse(self, blob: Blob) -> List[Document]:
def parse(self, blob: Blob) -> list[Document]:
"""Eagerly parse the blob into a document or documents.
This is a convenience method for interactive development environment.

View File

@@ -4,7 +4,7 @@ import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
from typing import Any, Generator, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator, model_validator
@@ -138,7 +138,7 @@ class Blob(BaseMedia):
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
@@ -285,7 +285,7 @@ class Document(BaseMedia):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "document"]

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings
def sorted_values(values: Dict[str, str]) -> List[Any]:
def sorted_values(values: dict[str, str]) -> list[Any]:
"""Return a list of values in dict sorted by key.
Args:
@@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
"""VectorStore that contains information about examples."""
k: int = 4
"""Number of examples to select."""
example_keys: Optional[List[str]] = None
example_keys: Optional[list[str]] = None
"""Optional keys to filter examples to."""
input_keys: Optional[List[str]] = None
input_keys: Optional[list[str]] = None
"""Optional keys to filter input to. If provided, the search is based on
the input variables instead of all variables."""
vectorstore_kwargs: Optional[Dict[str, Any]] = None
vectorstore_kwargs: Optional[dict[str, Any]] = None
"""Extra arguments passed to similarity_search function of the vectorstore."""
model_config = ConfigDict(
@@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
@staticmethod
def _example_to_text(
example: Dict[str, str], input_keys: Optional[List[str]]
example: dict[str, str], input_keys: Optional[list[str]]
) -> str:
if input_keys:
return " ".join(sorted_values({key: example[key] for key in input_keys}))
else:
return " ".join(sorted_values(example))
def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
# Get the examples from the metadata.
# This assumes that examples are stored in metadata.
examples = [dict(e.metadata) for e in documents]
@@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
return examples
def add_example(self, example: Dict[str, str]) -> str:
def add_example(self, example: dict[str, str]) -> str:
"""Add a new example to vectorstore.
Args:
@@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
)
return ids[0]
async def aadd_example(self, example: Dict[str, str]) -> str:
async def aadd_example(self, example: dict[str, str]) -> str:
"""Async add new example to vectorstore.
Args:
@@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
"""Select examples based on semantic similarity."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select examples based on semantic similarity.
Args:
@@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
)
return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Asynchronously select examples based on semantic similarity.
Args:
@@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
@classmethod
def from_examples(
cls,
examples: List[dict],
examples: list[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
vectorstore_cls: type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
input_keys: Optional[list[str]] = None,
*,
example_keys: Optional[List[str]] = None,
example_keys: Optional[list[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
@@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
@classmethod
async def afrom_examples(
cls,
examples: List[dict],
examples: list[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
vectorstore_cls: type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
input_keys: Optional[list[str]] = None,
*,
example_keys: Optional[List[str]] = None,
example_keys: Optional[list[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
@@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
fetch_k: int = 20
"""Number of examples to fetch to rerank."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select examples based on Max Marginal Relevance.
Args:
@@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
)
return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Asynchronously select examples based on Max Marginal Relevance.
Args:
@@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
@classmethod
def from_examples(
cls,
examples: List[dict],
examples: list[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
vectorstore_cls: type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
input_keys: Optional[list[str]] = None,
fetch_k: int = 20,
example_keys: Optional[List[str]] = None,
example_keys: Optional[list[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:
@@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
@classmethod
async def afrom_examples(
cls,
examples: List[dict],
examples: list[dict],
embeddings: Embeddings,
vectorstore_cls: Type[VectorStore],
vectorstore_cls: type[VectorStore],
*,
k: int = 4,
input_keys: Optional[List[str]] = None,
input_keys: Optional[list[str]] = None,
fetch_k: int = 20,
example_keys: Optional[List[str]] = None,
example_keys: Optional[list[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:

View File

@@ -8,7 +8,6 @@ from typing import (
Collection,
Iterable,
Iterator,
List,
Optional,
)
@@ -68,7 +67,7 @@ class Node(Serializable):
"""Text contained by the node."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""
links: List[Link] = Field(default_factory=list)
links: list[Link] = Field(default_factory=list)
"""Links associated with the node."""
@@ -189,7 +188,7 @@ class GraphVectorStore(VectorStore):
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
@@ -237,7 +236,7 @@ class GraphVectorStore(VectorStore):
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
@@ -282,7 +281,7 @@ class GraphVectorStore(VectorStore):
self,
documents: Iterable[Document],
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
@@ -332,7 +331,7 @@ class GraphVectorStore(VectorStore):
self,
documents: Iterable[Document],
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
@@ -535,7 +534,7 @@ class GraphVectorStore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
return list(self.traversal_search(query, k=k, depth=0))
def max_marginal_relevance_search(
@@ -545,7 +544,7 @@ class GraphVectorStore(VectorStore):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
@@ -554,10 +553,10 @@ class GraphVectorStore(VectorStore):
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
@@ -580,7 +579,7 @@ class GraphVectorStore(VectorStore):
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
@@ -679,7 +678,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
if self.search_type == "traversal":
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
@@ -691,7 +690,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
if self.search_type == "traversal":
return [
doc

View File

@@ -11,14 +11,11 @@ from typing import (
AsyncIterable,
AsyncIterator,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Sequence,
Set,
TypedDict,
TypeVar,
Union,
@@ -71,7 +68,7 @@ class _HashedDocument(Document):
@model_validator(mode="before")
@classmethod
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
def calculate_hashes(cls, values: dict[str, Any]) -> Any:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
@@ -125,7 +122,7 @@ class _HashedDocument(Document):
)
def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
"""Utility batching function."""
it = iter(iterable)
while True:
@@ -135,9 +132,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
yield chunk
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
"""Utility batching function."""
batch: List[T] = []
batch: list[T] = []
async for element in iterable:
if len(batch) < size:
batch.append(element)
@@ -171,7 +168,7 @@ def _deduplicate_in_order(
hashed_documents: Iterable[_HashedDocument],
) -> Iterator[_HashedDocument]:
"""Deduplicate a list of hashed documents while preserving order."""
seen: Set[str] = set()
seen: set[str] = set()
for hashed_doc in hashed_documents:
if hashed_doc.hash_ not in seen:
@@ -349,7 +346,7 @@ def index(
uids = []
docs_to_index = []
uids_to_refresh = []
seen_docs: Set[str] = set()
seen_docs: set[str] = set()
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
if doc_exists:
if force_update:
@@ -589,7 +586,7 @@ async def aindex(
uids: list[str] = []
docs_to_index: list[Document] = []
uids_to_refresh = []
seen_docs: Set[str] = set()
seen_docs: set[str] = set()
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
if doc_exists:
if force_update:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, TypedDict
from typing import Any, Optional, Sequence, TypedDict
from langchain_core._api import beta
from langchain_core.documents import Document
@@ -144,7 +144,7 @@ class RecordManager(ABC):
"""
@abstractmethod
def exists(self, keys: Sequence[str]) -> List[bool]:
def exists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the provided keys exist in the database.
Args:
@@ -155,7 +155,7 @@ class RecordManager(ABC):
"""
@abstractmethod
async def aexists(self, keys: Sequence[str]) -> List[bool]:
async def aexists(self, keys: Sequence[str]) -> list[bool]:
"""Asynchronously check if the provided keys exist in the database.
Args:
@@ -173,7 +173,7 @@ class RecordManager(ABC):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
) -> list[str]:
"""List records in the database based on the provided filters.
Args:
@@ -194,7 +194,7 @@ class RecordManager(ABC):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
) -> list[str]:
"""Asynchronously list records in the database based on the provided filters.
Args:
@@ -241,7 +241,7 @@ class InMemoryRecordManager(RecordManager):
super().__init__(namespace)
# Each key points to a dictionary
# of {'group_id': group_id, 'updated_at': timestamp}
self.records: Dict[str, _Record] = {}
self.records: dict[str, _Record] = {}
self.namespace = namespace
def create_schema(self) -> None:
@@ -325,7 +325,7 @@ class InMemoryRecordManager(RecordManager):
"""
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
def exists(self, keys: Sequence[str]) -> List[bool]:
def exists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the provided keys exist in the database.
Args:
@@ -336,7 +336,7 @@ class InMemoryRecordManager(RecordManager):
"""
return [key in self.records for key in keys]
async def aexists(self, keys: Sequence[str]) -> List[bool]:
async def aexists(self, keys: Sequence[str]) -> list[bool]:
"""Async check if the provided keys exist in the database.
Args:
@@ -354,7 +354,7 @@ class InMemoryRecordManager(RecordManager):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
) -> list[str]:
"""List records in the database based on the provided filters.
Args:
@@ -390,7 +390,7 @@ class InMemoryRecordManager(RecordManager):
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
) -> list[str]:
"""Async list records in the database based on the provided filters.
Args:
@@ -449,9 +449,9 @@ class UpsertResponse(TypedDict):
indexed to avoid this issue.
"""
succeeded: List[str]
succeeded: list[str]
"""The IDs that were successfully indexed."""
failed: List[str]
failed: list[str]
"""The IDs that failed to index."""
@@ -562,7 +562,7 @@ class DocumentIndex(BaseRetriever):
)
@abc.abstractmethod
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
"""Delete by IDs or other criteria.
Calling delete without any input parameters should raise a ValueError!
@@ -579,7 +579,7 @@ class DocumentIndex(BaseRetriever):
"""
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
self, ids: Optional[list[str]] = None, **kwargs: Any
) -> DeleteResponse:
"""Delete by IDs or other criteria. Async variant.
@@ -607,7 +607,7 @@ class DocumentIndex(BaseRetriever):
ids: Sequence[str],
/,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Get documents by id.
Fewer documents may be returned than requested if some IDs are not found or
@@ -633,7 +633,7 @@ class DocumentIndex(BaseRetriever):
ids: Sequence[str],
/,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Get documents by id.
Fewer documents may be returned than requested if some IDs are not found or

View File

@@ -6,14 +6,11 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
)
@@ -51,7 +48,7 @@ class LangSmithParams(TypedDict, total=False):
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
ls_stop: Optional[list[str]]
"""Stop words for generation."""
@@ -74,7 +71,7 @@ def get_tokenizer() -> Any:
return GPT2TokenizerFast.from_pretrained("gpt2")
def _get_token_ids_default_method(text: str) -> List[int]:
def _get_token_ids_default_method(text: str) -> list[int]:
"""Encode the text into token IDs."""
# get the cached tokenizer
tokenizer = get_tokenizer()
@@ -117,11 +114,11 @@ class BaseLanguageModel(
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
tags: Optional[list[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
default=None, exclude=True
)
"""Optional encoder to use for counting tokens."""
@@ -167,8 +164,8 @@ class BaseLanguageModel(
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -202,8 +199,8 @@ class BaseLanguageModel(
@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -235,8 +232,8 @@ class BaseLanguageModel(
"""
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
@@ -267,7 +264,7 @@ class BaseLanguageModel(
@abstractmethod
def predict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -313,7 +310,7 @@ class BaseLanguageModel(
@abstractmethod
async def apredict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@@ -339,7 +336,7 @@ class BaseLanguageModel(
"""Get the identifying parameters."""
return self.lc_attributes
def get_token_ids(self, text: str) -> List[int]:
def get_token_ids(self, text: str) -> list[int]:
"""Return the ordered ids of the tokens in a text.
Args:
@@ -367,7 +364,7 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
@@ -381,7 +378,7 @@ class BaseLanguageModel(
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def _all_required_field_names(cls) -> Set:
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.
Use get_pydantic_field_names.

View File

@@ -15,11 +15,9 @@ from typing import (
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
@@ -223,7 +221,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
@@ -277,7 +275,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -300,7 +298,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
@@ -356,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
@@ -426,7 +424,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
@@ -499,12 +497,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
return {}
def _get_invocation_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
@@ -513,7 +511,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _get_ls_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@@ -550,7 +548,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
return ls_params
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
@@ -567,12 +565,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
messages: list[list[BaseMessage]],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -658,12 +656,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def agenerate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
messages: list[list[BaseMessage]],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
**kwargs: Any,
@@ -777,8 +775,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -787,8 +785,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
@@ -799,8 +797,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -839,7 +837,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
chunks: list[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@@ -876,8 +874,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _agenerate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@@ -916,7 +914,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
chunks: list[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
if run_manager:
@ -954,8 +952,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -963,8 +961,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -980,8 +978,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -989,8 +987,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@ -1017,8 +1015,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def __call__(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@ -1032,8 +1030,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async def _call_async(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
@ -1048,7 +1046,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
self, message: str, stop: Optional[list[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
@ -1069,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@ -1099,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@ -1115,7 +1113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@ -1123,18 +1121,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
def with_structured_output(
self,
schema: Union[Dict, Type],
schema: Union[Dict, type], # noqa: UP006
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@ -1281,8 +1279,8 @@ class SimpleChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -1294,8 +1292,8 @@ class SimpleChatModel(BaseChatModel):
@abstractmethod
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -1303,8 +1301,8 @@ class SimpleChatModel(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

View File

@ -20,8 +20,6 @@ from typing import (
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -76,7 +74,7 @@ def _log_error_once(msg: str) -> None:
def create_base_retry_decorator(
error_types: List[Type[BaseException]],
error_types: list[type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
@ -153,10 +151,10 @@ def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
def get_prompts(
params: Dict[str, Any],
prompts: List[str],
params: dict[str, Any],
prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached.
Args:
@ -189,10 +187,10 @@ def get_prompts(
async def aget_prompts(
params: Dict[str, Any],
prompts: List[str],
params: dict[str, Any],
prompts: list[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
) -> tuple[dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached. Async version.
Args:
@ -225,11 +223,11 @@ async def aget_prompts(
def update_cache(
cache: Union[BaseCache, bool, None],
existing_prompts: Dict[int, List],
existing_prompts: dict[int, list],
llm_string: str,
missing_prompt_idxs: List[int],
missing_prompt_idxs: list[int],
new_results: LLMResult,
prompts: List[str],
prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output.
@ -259,11 +257,11 @@ def update_cache(
async def aupdate_cache(
cache: Union[BaseCache, bool, None],
existing_prompts: Dict[int, List],
existing_prompts: dict[int, list],
llm_string: str,
missing_prompt_idxs: List[int],
missing_prompt_idxs: list[int],
new_results: LLMResult,
prompts: List[str],
prompts: list[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version.
@ -306,7 +304,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
@ -324,7 +322,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods ---
@property
def OutputType(self) -> Type[str]:
def OutputType(self) -> type[str]:
"""Get the input type for this runnable."""
return str
@ -343,7 +341,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _get_ls_params(
self,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
@ -383,7 +381,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@ -407,7 +405,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> str:
config = ensure_config(config)
@ -425,12 +423,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
if not inputs:
return []
@ -472,12 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
if not inputs:
return []
config = get_config_list(config, len(inputs))
@ -521,7 +519,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
@ -583,7 +581,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if (
@ -649,8 +647,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@abstractmethod
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -658,8 +656,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -676,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
@ -704,7 +702,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
@ -747,9 +745,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@ -757,9 +755,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[PromptValue],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@ -769,9 +767,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _generate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[CallbackManagerForLLMRun],
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@ -802,14 +800,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
tags: Optional[Union[list[str], list[list[str]]]] = None,
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
run_name: Optional[Union[str, list[str]]] = None,
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to a model and return generations.
@ -987,7 +985,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@staticmethod
def _get_run_ids_list(
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
@ -1002,9 +1000,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _agenerate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[AsyncCallbackManagerForLLMRun],
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
@ -1044,14 +1042,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
*,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
run_name: Optional[Union[str, List[str]]] = None,
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
tags: Optional[Union[list[str], list[list[str]]]] = None,
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
run_name: Optional[Union[str, list[str]]] = None,
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Asynchronously pass a sequence of prompts to a model and return generations.
@ -1239,11 +1237,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def __call__(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input.
@ -1287,11 +1285,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
async def _call_async(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
@ -1318,7 +1316,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def predict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@ -1344,7 +1342,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
async def apredict_messages(
self,
messages: List[BaseMessage],
messages: list[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
@ -1367,7 +1365,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
def _llm_type(self) -> str:
"""Return type of llm."""
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@ -1443,7 +1441,7 @@ class LLM(BaseLLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -1467,7 +1465,7 @@ class LLM(BaseLLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -1500,8 +1498,8 @@ class LLM(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
@ -1520,8 +1518,8 @@ class LLM(BaseLLM):
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:

View File

@ -11,7 +11,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any
from pydantic import ConfigDict
@ -55,11 +55,11 @@ class BaseMemory(Serializable, ABC):
@property
@abstractmethod
def memory_variables(self) -> List[str]:
def memory_variables(self) -> list[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return key-value pairs given the text input to the chain.
Args:
@ -69,7 +69,7 @@ class BaseMemory(Serializable, ABC):
A dictionary of key-value pairs.
"""
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Async return key-value pairs given the text input to the chain.
Args:
@ -81,7 +81,7 @@ class BaseMemory(Serializable, ABC):
return await run_in_executor(None, self.load_memory_variables, inputs)
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
"""Save the context of this chain run to memory.
Args:
@ -90,7 +90,7 @@ class BaseMemory(Serializable, ABC):
"""
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
self, inputs: dict[str, Any], outputs: dict[str, str]
) -> None:
"""Async save the context of this chain run to memory.

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
from pydantic import ConfigDict, Field, field_validator
@ -19,7 +19,7 @@ class BaseMessage(Serializable):
Messages are the inputs and outputs of ChatModels.
"""
content: Union[str, List[Union[str, Dict]]]
content: Union[str, list[Union[str, dict]]]
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
@ -64,7 +64,7 @@ class BaseMessage(Serializable):
return id_value
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
@ -85,7 +85,7 @@ class BaseMessage(Serializable):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
@ -119,9 +119,9 @@ class BaseMessage(Serializable):
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
*contents: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
first_content: Union[str, list[Union[str, dict]]],
*contents: Union[str, list[Union[str, dict]]],
) -> Union[str, list[Union[str, dict]]]:
"""Merge two message contents.
Args:
@ -163,7 +163,7 @@ class BaseMessageChunk(BaseMessage):
"""Message chunk, which can be concatenated with other Message chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
@ -242,7 +242,7 @@ def message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.model_dump()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
def messages_to_dict(messages: Sequence[BaseMessage]) -> list[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:

View File

@ -23,7 +23,6 @@ from typing import (
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
overload,
@ -166,7 +165,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
raise ValueError(f"Got unexpected message type: {_type}")
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
@ -208,7 +207,7 @@ def _create_message_from_message_type(
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
@ -230,7 +229,7 @@ def _create_message_from_message_type(
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "system", "function", or "tool".
"""
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
@ -331,7 +330,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
def convert_to_messages(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
@ -352,18 +351,18 @@ def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
messages: Literal[None] = None, **kwargs: Any
) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ...
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
@overload
def wrapped(
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
) -> List[BaseMessage]: ...
) -> list[BaseMessage]: ...
def wrapped(
messages: Optional[Sequence[MessageLikeRepresentation]] = None, **kwargs: Any
) -> Union[
List[BaseMessage],
Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
list[BaseMessage],
Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
]:
from langchain_core.runnables.base import RunnableLambda
@ -382,11 +381,11 @@ def filter_messages(
*,
include_names: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
exclude_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
include_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Filter messages based on name, type or id.
Args:
@ -438,7 +437,7 @@ def filter_messages(
]
""" # noqa: E501
messages = convert_to_messages(messages)
filtered: List[BaseMessage] = []
filtered: list[BaseMessage] = []
for msg in messages:
if exclude_names and msg.name in exclude_names:
continue
@ -469,7 +468,7 @@ def merge_message_runs(
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
*,
chunk_separator: str = "\n",
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Merge consecutive Messages of the same type.
**NOTE**: ToolMessages are not merged, as each has a distinct tool call id that
@ -539,7 +538,7 @@ def merge_message_runs(
if not messages:
return []
messages = convert_to_messages(messages)
merged: List[BaseMessage] = []
merged: list[BaseMessage] = []
for msg in messages:
curr = msg.model_copy(deep=True)
last = merged.pop() if merged else None
@ -569,21 +568,21 @@ def trim_messages(
*,
max_tokens: int,
token_counter: Union[
Callable[[List[BaseMessage]], int],
Callable[[list[BaseMessage]], int],
Callable[[BaseMessage], int],
BaseLanguageModel,
],
strategy: Literal["first", "last"] = "last",
allow_partial: bool = False,
end_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
start_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
include_system: bool = False,
text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None,
) -> List[BaseMessage]:
text_splitter: Optional[Union[Callable[[str], list[str]], TextSplitter]] = None,
) -> list[BaseMessage]:
"""Trim messages to be below a token count.
Args:
@ -875,13 +874,13 @@ def _first_max_tokens(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
token_counter: Callable[[List[BaseMessage]], int],
text_splitter: Callable[[str], List[str]],
token_counter: Callable[[list[BaseMessage]], int],
text_splitter: Callable[[str], list[str]],
partial_strategy: Optional[Literal["first", "last"]] = None,
end_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
) -> List[BaseMessage]:
) -> list[BaseMessage]:
messages = list(messages)
idx = 0
for i in range(len(messages)):
@ -949,17 +948,17 @@ def _last_max_tokens(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
token_counter: Callable[[List[BaseMessage]], int],
text_splitter: Callable[[str], List[str]],
token_counter: Callable[[list[BaseMessage]], int],
text_splitter: Callable[[str], list[str]],
allow_partial: bool = False,
include_system: bool = False,
start_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
end_on: Optional[
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
] = None,
) -> List[BaseMessage]:
) -> list[BaseMessage]:
messages = list(messages)
if end_on:
while messages and not _is_message_type(messages[-1], end_on):
@ -984,7 +983,7 @@ def _last_max_tokens(
return reversed_[::-1]
_MSG_CHUNK_MAP: Dict[Type[BaseMessage], Type[BaseMessageChunk]] = {
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
HumanMessage: HumanMessageChunk,
AIMessage: AIMessageChunk,
SystemMessage: SystemMessageChunk,
@ -1024,14 +1023,14 @@ def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
)
def _default_text_splitter(text: str) -> List[str]:
def _default_text_splitter(text: str) -> list[str]:
splits = text.split("\n")
return [s + "\n" for s in splits[:-1]] + splits[-1:]
def _is_message_type(
message: BaseMessage,
type_: Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]],
type_: Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]],
) -> bool:
types = [type_] if isinstance(type_, (str, type)) else type_
types_str = [t for t in types if isinstance(t, str)]

View File

@ -4,11 +4,8 @@ from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
@ -30,7 +27,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -44,7 +41,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: list[Generation], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.
@ -71,7 +68,7 @@ class BaseGenerationOutputParser(
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
def OutputType(self) -> type[T]:
"""Return the output type for the parser."""
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
@ -156,7 +153,7 @@ class BaseOutputParser(
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
def OutputType(self) -> type[T]:
"""Return the output type for the parser.
This property is inferred from the first type argument of the class.
@ -218,7 +215,7 @@ class BaseOutputParser(
run_type="parser",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -247,7 +244,7 @@ class BaseOutputParser(
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: list[Generation], *, partial: bool = False
) -> T:
"""Async parse a list of candidate model Generations into a specific format.
@ -305,7 +302,7 @@ class BaseOutputParser(
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
try:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import json
from json import JSONDecodeError
from typing import Any, List, Optional, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic
@ -42,14 +42,14 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]:
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
if PYDANTIC_MAJOR_VERSION == 2:
if issubclass(pydantic_object, pydantic.BaseModel):
return pydantic_object.model_json_schema()
@ -57,7 +57,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.model_json_schema()
return pydantic_object.model_json_schema()
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
from typing import AsyncIterator, Iterator, List, TypeVar, Union
from typing import Optional as Optional
from langchain_core.messages import BaseMessage
@ -22,7 +22,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
Yields:
The elements of the iterator, except the last n elements.
"""
buffer: Deque[T] = deque()
buffer: deque[T] = deque()
for item in iter:
buffer.append(item)
if len(buffer) > n:
@ -37,7 +37,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
return "list"
@abstractmethod
def parse(self, text: str) -> List[str]:
def parse(self, text: str) -> list[str]:
"""Parse the output of an LLM call.
Args:
@ -60,7 +60,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[List[str]]:
) -> Iterator[list[str]]:
buffer = ""
for chunk in input:
if isinstance(chunk, BaseMessage):
@ -92,7 +92,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[List[str]]:
) -> AsyncIterator[list[str]]:
buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
@ -136,7 +136,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns:
@ -152,7 +152,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
"eg: `foo, bar, baz` or `foo,bar,baz`"
)
def parse(self, text: str) -> List[str]:
def parse(self, text: str) -> list[str]:
"""Parse the output of an LLM call.
Args:
@ -180,7 +180,7 @@ class NumberedListOutputParser(ListOutputParser):
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
)
def parse(self, text: str) -> List[str]:
def parse(self, text: str) -> list[str]:
"""Parse the output of an LLM call.
Args:
@ -217,7 +217,7 @@ class MarkdownListOutputParser(ListOutputParser):
"""Return the format instructions for the Markdown list output."""
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]:
def parse(self, text: str) -> list[str]:
"""Parse the output of an LLM call.
Args:

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Literal, Union
from typing import Literal, Union
from pydantic import model_validator
from typing_extensions import Self
@ -69,7 +69,7 @@ class ChatGeneration(Generation):
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
@ -86,12 +86,12 @@ class ChatGenerationChunk(ChatGeneration):
"""Type is used exclusively for serialization purposes."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
def __add__(
self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
self, other: Union[ChatGenerationChunk, list[ChatGenerationChunk]]
) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = merge_dicts(

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional
from langchain_core.load import Serializable
from langchain_core.utils._merge import merge_dicts
@ -25,7 +25,7 @@ class Generation(Serializable):
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
generation_info: Optional[dict[str, Any]] = None
"""Raw response from the provider.
May include things like the reason for finishing or token log probabilities.
@ -40,7 +40,7 @@ class Generation(Serializable):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
@ -49,7 +49,7 @@ class GenerationChunk(Generation):
"""Generation chunk, which can be concatenated with other Generation chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Literal, Optional, Union
from typing import Literal, Optional, Union
from pydantic import BaseModel
@ -18,8 +18,8 @@ class LLMResult(BaseModel):
wants to return.
"""
generations: List[
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
generations: list[
list[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
]
"""Generated outputs.
@ -45,13 +45,13 @@ class LLMResult(BaseModel):
accessing relevant information from standardized fields present in
AIMessage.
"""
run: Optional[List[RunInfo]] = None
run: Optional[list[RunInfo]] = None
"""List of metadata info for model call for each input."""
type: Literal["LLMResult"] = "LLMResult" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
def flatten(self) -> List[LLMResult]:
def flatten(self) -> list[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult

View File

@ -7,7 +7,7 @@ They can be used to represent text, images, or chat message pieces.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence, cast
from typing import Literal, Sequence, cast
from typing_extensions import TypedDict
@ -33,7 +33,7 @@ class PromptValue(Serializable, ABC):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "schema", "prompt"].
@ -45,7 +45,7 @@ class PromptValue(Serializable, ABC):
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as a list of Messages."""
@ -57,7 +57,7 @@ class StringPromptValue(PromptValue):
type: Literal["StringPromptValue"] = "StringPromptValue"
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "base"].
@ -68,7 +68,7 @@ class StringPromptValue(PromptValue):
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
@ -86,12 +86,12 @@ class ChatPromptValue(PromptValue):
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
def to_messages(self) -> list[BaseMessage]:
"""Return prompt as a list of messages."""
return list(self.messages)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].
@ -121,7 +121,7 @@ class ImagePromptValue(PromptValue):
"""Return prompt (image URL) as string."""
return self.image_url["url"]
def to_messages(self) -> List[BaseMessage]:
def to_messages(self) -> list[BaseMessage]:
"""Return prompt (image URL) as messages."""
return [HumanMessage(content=[cast(dict, self.image_url)])]
@ -136,7 +136,7 @@ class ChatPromptValueConcrete(ChatPromptValue):
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
This is used to determine the namespace of the object when serializing.
Defaults to ["langchain", "prompts", "chat"].

View File

@ -10,10 +10,8 @@ from typing import (
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
@ -45,14 +43,14 @@ class BasePromptTemplate(
):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
input_variables: list[str]
"""A list of the names of the variables whose values are required as inputs to the
prompt."""
optional_variables: List[str] = Field(default=[])
optional_variables: list[str] = Field(default=[])
"""optional_variables: A list of the names of the variables for placeholder
or MessagePlaceholder that are optional. These variables are auto inferred
from the prompt and user need not provide them."""
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
@ -62,9 +60,9 @@ class BasePromptTemplate(
Partial variables populate the template so that you don't need to
pass them in every time you call the prompt."""
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None # noqa: UP006
"""Metadata to be used for tracing."""
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
"""Tags to be used for tracing."""
@model_validator(mode="after")
@ -89,7 +87,7 @@ class BasePromptTemplate(
return self
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns ["langchain", "schema", "prompt_template"]."""
return ["langchain", "schema", "prompt_template"]
@ -115,7 +113,7 @@ class BasePromptTemplate(
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the input schema for the prompt.
Args:
@ -136,7 +134,7 @@ class BasePromptTemplate(
field_definitions={**required_input_variables, **optional_input_variables},
)
def _validate_input(self, inner_input: Any) -> Dict:
def _validate_input(self, inner_input: Any) -> dict:
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
@ -163,18 +161,18 @@ class BasePromptTemplate(
raise KeyError(msg)
return inner_input
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return self.format_prompt(**_inner_input)
async def _aformat_prompt_with_error_handling(
self, inner_input: Dict
self, inner_input: dict
) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return await self.aformat_prompt(**_inner_input)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
self, input: dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
"""Invoke the prompt.
@ -199,7 +197,7 @@ class BasePromptTemplate(
)
async def ainvoke(
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> PromptValue:
"""Async invoke the prompt.
@ -261,7 +259,7 @@ class BasePromptTemplate(
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
@ -307,7 +305,7 @@ class BasePromptTemplate(
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Return dictionary representation of prompt.
Args:
@ -369,7 +367,7 @@ class BasePromptTemplate(
raise ValueError(f"{save_path} must be json or yaml")
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:

View File

@ -6,12 +6,10 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Type,
TypedDict,
@ -60,12 +58,12 @@ class BaseMessagePromptTemplate(Serializable, ABC):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
@ -75,7 +73,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
List of BaseMessages.
"""
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format messages from kwargs.
Should return a list of BaseMessages.
@ -89,7 +87,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
@property
@abstractmethod
def input_variables(self) -> List[str]:
def input_variables(self) -> list[str]:
"""Input variables for this prompt template.
Returns:
@ -210,7 +208,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
Defaults to None."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@ -223,7 +221,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs.
Args:
@ -251,7 +249,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
return value
@property
def input_variables(self) -> List[str]:
def input_variables(self) -> list[str]:
"""Input variables for this prompt template.
Returns:
@ -292,16 +290,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""Additional keyword arguments to pass to the prompt template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
cls: type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a string template.
@ -329,9 +327,9 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod
def from_template_file(
cls: Type[MessagePromptTemplateT],
cls: type[MessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
input_variables: list[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a template file.
@ -369,7 +367,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""
return self.format(**kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs.
Args:
@ -380,7 +378,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""
return [self.format(**kwargs)]
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format messages from kwargs.
Args:
@ -392,7 +390,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
return [await self.aformat(**kwargs)]
@property
def input_variables(self) -> List[str]:
def input_variables(self) -> list[str]:
"""
Input variables for this prompt template.
@ -423,7 +421,7 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Role of the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@ -462,37 +460,37 @@ _StringImageMessagePromptTemplateT = TypeVar(
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
text: Union[str, dict]
class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, Dict]
image_url: Union[str, dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
_msg_class: Type[BaseMessage]
_msg_class: type[BaseMessage]
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@classmethod
def from_template(
cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
cls: type[_StringImageMessagePromptTemplateT],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
*,
partial_variables: Optional[Dict[str, Any]] = None,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a string template.
@ -511,7 +509,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
ValueError: If the template is not a string or list of strings.
"""
if isinstance(template, str):
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
prompt: Union[StringPromptTemplate, list] = PromptTemplate.from_template(
template,
template_format=template_format,
partial_variables=partial_variables,
@ -574,9 +572,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template_file(
cls: Type[_StringImageMessagePromptTemplateT],
cls: type[_StringImageMessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
input_variables: list[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a template file.
@ -593,7 +591,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs.
Args:
@ -604,7 +602,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""
return [self.format(**kwargs)]
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format messages from kwargs.
Args:
@ -616,7 +614,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
return [await self.aformat(**kwargs)]
@property
def input_variables(self) -> List[str]:
def input_variables(self) -> list[str]:
"""
Input variables for this prompt template.
@ -642,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
content=text, additional_kwargs=self.additional_kwargs
)
else:
content: List = []
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
@ -670,7 +668,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
content=text, additional_kwargs=self.additional_kwargs
)
else:
content: List = []
content: list = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
@ -703,16 +701,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
_msg_class: Type[BaseMessage] = HumanMessage
_msg_class: type[BaseMessage] = HumanMessage
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
_msg_class: Type[BaseMessage] = AIMessage
_msg_class: type[BaseMessage] = AIMessage
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@ -722,10 +720,10 @@ class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
This is a message that is not sent to the user.
"""
_msg_class: Type[BaseMessage] = SystemMessage
_msg_class: type[BaseMessage] = SystemMessage
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@ -734,7 +732,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""
Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the
@ -791,10 +789,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
return ChatPromptValue(messages=messages)
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format kwargs into a list of messages."""
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format kwargs into a list of messages."""
return self.format_messages(**kwargs)
@ -935,7 +933,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
""" # noqa: E501
messages: Annotated[List[MessageLike], SkipValidation()]
messages: Annotated[list[MessageLike], SkipValidation()]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@ -999,9 +997,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
optional_variables: Set[str] = set()
partial_vars: Dict[str, Any] = {}
input_vars: set[str] = set()
optional_variables: set[str] = set()
partial_vars: dict[str, Any] = {}
for _message in _messages:
if isinstance(_message, MessagesPlaceholder) and _message.optional:
partial_vars[_message.variable_name] = []
@ -1022,7 +1020,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@ -1071,7 +1069,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
messages = values["messages"]
input_vars = set()
optional_variables = set()
input_types: Dict[str, Any] = values.get("input_types", {})
input_types: dict[str, Any] = values.get("input_types", {})
for message in messages:
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
input_vars.update(message.input_variables)
@ -1125,7 +1123,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
cls, string_messages: list[tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role, template) tuples.
@ -1145,7 +1143,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a chat prompt template from a list of (role class, template) tuples.
@ -1200,7 +1198,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""
return cls(messages, template_format=template_format)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format the chat template into a list of finalized messages.
Args:
@ -1224,7 +1222,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
raise ValueError(f"Unexpected input: {message_template}")
return result
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format the chat template into a list of finalized messages.
Args:

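A minimal usage sketch (not part of the commit) of the chat-prompt API diffed above, in the PEP 585 style UP006 enforces. `list[BaseMessage]` needs no typing import, and the `from __future__ import annotations` already present in these modules keeps the syntax safe on older interpreters:

    from langchain_core.messages import BaseMessage
    from langchain_core.prompts import ChatPromptTemplate

    # Builtin generics replace typing.List in user code as well.
    template = ChatPromptTemplate.from_messages(
        [("system", "You are a careful assistant."), ("human", "{question}")]
    )
    messages: list[BaseMessage] = template.format_messages(question="What changed?")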
View File

@ -3,7 +3,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from pydantic import (
BaseModel,
@ -31,7 +31,7 @@ from langchain_core.prompts.string import (
class _FewShotPromptTemplateMixin(BaseModel):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
examples: Optional[list[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
@ -46,7 +46,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
def check_examples_and_selector(cls, values: dict) -> Any:
"""Check that one and only one of examples/example_selector are provided.
Args:
@ -73,7 +73,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
return values
def _get_examples(self, **kwargs: Any) -> List[dict]:
def _get_examples(self, **kwargs: Any) -> list[dict]:
"""Get the examples to use for formatting the prompt.
Args:
@ -94,7 +94,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
"One of 'examples' and 'example_selector' should be provided"
)
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
"""Async get the examples to use for formatting the prompt.
Args:
@ -363,7 +363,7 @@ class FewShotChatMessagePromptTemplate(
chain.invoke({"input": "What's 3+3?"})
"""
input_variables: List[str] = Field(default_factory=list)
input_variables: list[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template will use
to pass to the example_selector, if provided."""
@ -380,7 +380,7 @@ class FewShotChatMessagePromptTemplate(
extra="forbid",
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format kwargs into a list of messages.
Args:
@ -402,7 +402,7 @@ class FewShotChatMessagePromptTemplate(
]
return messages
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format kwargs into a list of messages.
Args:

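Why UP006 is safe for these Pydantic fields: Pydantic v2 resolves builtin-generic annotations such as `Optional[list[dict]]` itself, so fields like `examples` above validate exactly as before. A hypothetical standalone model (names invented) illustrating the same pattern:

    from typing import Optional

    from pydantic import BaseModel

    class FewShotExamples(BaseModel):
        # Pydantic v2 understands builtin generics; typing.List is unnecessary.
        examples: Optional[list[dict]] = None

    FewShotExamples(examples=[{"input": "2+2", "output": "4"}])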
View File

@ -4,7 +4,7 @@ from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, model_validator
@ -54,13 +54,13 @@ class PromptTemplate(StringPromptTemplate):
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
def lc_attributes(self) -> dict[str, Any]:
return {
"template_format": self.template_format,
}
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "prompt"]
@ -76,7 +76,7 @@ class PromptTemplate(StringPromptTemplate):
@model_validator(mode="before")
@classmethod
def pre_init_validation(cls, values: Dict) -> Any:
def pre_init_validation(cls, values: dict) -> Any:
"""Check that template and input variables are consistent."""
if values.get("template") is None:
# Will let pydantic fail with a ValidationError if template
@ -183,9 +183,9 @@ class PromptTemplate(StringPromptTemplate):
@classmethod
def from_examples(
cls,
examples: List[str],
examples: list[str],
suffix: str,
input_variables: List[str],
input_variables: list[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
@ -215,7 +215,7 @@ class PromptTemplate(StringPromptTemplate):
def from_file(
cls,
template_file: Union[str, Path],
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
encoding: Optional[str] = None,
**kwargs: Any,
) -> PromptTemplate:
@ -249,7 +249,7 @@ class PromptTemplate(StringPromptTemplate):
template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt template from a template.

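A short sketch (illustrative, not from the commit) exercising `PromptTemplate.from_template` with the `partial_variables` parameter retyped above:

    from langchain_core.prompts import PromptTemplate

    prompt = PromptTemplate.from_template(
        "Tell me a {adjective} joke about {content}.",
        partial_variables={"adjective": "short"},  # now typed dict[str, Any]
    )
    print(prompt.format(content="type annotations"))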
View File

@ -5,7 +5,7 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type
from typing import Any, Callable, Dict
from pydantic import BaseModel, create_model
@ -60,7 +60,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
def validate_jinja2(template: str, input_variables: list[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
@ -85,7 +85,7 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
warnings.warn(warning_message.strip(), stacklevel=7)
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
def _get_jinja2_variables_from_template(template: str) -> set[str]:
try:
from jinja2 import Environment, meta
except ImportError as e:
@ -114,7 +114,7 @@ def mustache_formatter(template: str, **kwargs: Any) -> str:
def mustache_template_vars(
template: str,
) -> Set[str]:
) -> set[str]:
"""Get the variables from a mustache template.
Args:
@ -123,7 +123,7 @@ def mustache_template_vars(
Returns:
The variables from the template.
"""
vars: Set[str] = set()
vars: set[str] = set()
section_depth = 0
for type, key in mustache.tokenize(template):
if type == "end":
@ -144,7 +144,7 @@ Defs = Dict[str, "Defs"]
def mustache_schema(
template: str,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the variables from a mustache template.
Args:
@ -154,8 +154,8 @@ def mustache_schema(
The variables from the template as a Pydantic model.
"""
fields = {}
prefix: Tuple[str, ...] = ()
section_stack: List[Tuple[str, ...]] = []
prefix: tuple[str, ...] = ()
section_stack: list[tuple[str, ...]] = []
for type, key in mustache.tokenize(template):
if key == ".":
continue
@ -178,7 +178,7 @@ def mustache_schema(
return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type:
def _create_model_recursive(name: str, defs: Defs) -> type:
return create_model( # type: ignore[call-overload]
name,
**{
@ -188,20 +188,20 @@ def _create_model_recursive(name: str, defs: Defs) -> Type:
)
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
"f-string": formatter.format,
"mustache": mustache_formatter,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
template: str, template_format: str, input_variables: list[str]
) -> None:
"""Check that template string is valid.
@ -230,7 +230,7 @@ def check_valid_template(
) from exc
def get_template_variables(template: str, template_format: str) -> List[str]:
def get_template_variables(template: str, template_format: str) -> list[str]:
"""Get the variables from the template.
Args:
@ -262,7 +262,7 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "base"]

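The string-template helpers retyped above can be exercised directly; a quick sketch (assuming the `langchain_core.prompts.string` module path shown in this diff) of the new `list[str]` and `set[str]` return types:

    from langchain_core.prompts.string import (
        get_template_variables,
        mustache_template_vars,
    )

    f_vars: list[str] = get_template_variables("Hello {name}, it is {day}.", "f-string")
    m_vars: set[str] = mustache_template_vars("Hello {{name}}!")
    assert set(f_vars) == {"name", "day"} and m_vars == {"name"}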
View File

@ -24,7 +24,7 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional
from pydantic import ConfigDict
from typing_extensions import TypedDict
@ -132,14 +132,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
_new_arg_supported: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to, e.g., identify a specific instance of a retriever with its
use case.
"""
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the retriever. Defaults to None.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
@ -200,7 +200,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
"""Invoke the retriever to get relevant documents.
Main entry point for synchronous retriever invocations.
@ -263,7 +263,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Asynchronously invoke the retriever to get relevant documents.
Main entry point for asynchronous retriever invocations.
@ -324,7 +324,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
"""Get documents relevant to a query.
Args:
@ -336,7 +336,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
"""Asynchronously get documents relevant to a query.
Args:
@ -358,11 +358,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Retrieve documents relevant to a query.
Users should favor using `.invoke` or `.batch` rather than
@ -402,11 +402,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Asynchronously get documents relevant to a query.
Users should favor using `.ainvoke` or `.abatch` rather than

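A minimal custom retriever (hypothetical) built on the hooks retyped above, returning `list[Document]` in the new style:

    from langchain_core.callbacks import CallbackManagerForRetrieverRun
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever

    class KeywordRetriever(BaseRetriever):
        """Toy retriever: naive substring match over a fixed corpus."""

        docs: list[Document] = []

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> list[Document]:
            return [d for d in self.docs if query.lower() in d.page_content.lower()]

    retriever = KeywordRetriever(docs=[Document(page_content="PEP 585 generics")])
    retriever.invoke("pep 585")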
View File

@ -27,8 +27,6 @@ from typing import (
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
@ -273,7 +271,7 @@ class Runnable(Generic[Input, Output], ABC):
return name_
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
"""The type of input this Runnable accepts specified as a type annotation."""
# First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization
@ -298,7 +296,7 @@ class Runnable(Generic[Input, Output], ABC):
)
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
"""The type of output this Runnable produces specified as a type annotation."""
# First loop through bases -- this will help generic
# any pydantic models.
@ -319,13 +317,13 @@ class Runnable(Generic[Input, Output], ABC):
)
@property
def input_schema(self) -> Type[BaseModel]:
def input_schema(self) -> type[BaseModel]:
"""The type of input this Runnable accepts specified as a pydantic model."""
return self.get_input_schema()
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate input to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives
@ -360,7 +358,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get a JSON schema that represents the input to the Runnable.
Args:
@ -387,13 +385,13 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_input_schema(config).model_json_schema()
@property
def output_schema(self) -> Type[BaseModel]:
def output_schema(self) -> type[BaseModel]:
"""The type of output this Runnable produces specified as a pydantic model."""
return self.get_output_schema()
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives
@ -428,7 +426,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
@ -455,13 +453,13 @@ class Runnable(Generic[Input, Output], ABC):
return self.get_output_schema(config).model_json_schema()
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""List configurable fields for this Runnable."""
return []
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""The type of config this Runnable accepts specified as a pydantic model.
To mark a field as configurable, see the `configurable_fields`
@ -509,7 +507,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
@ -544,7 +542,7 @@ class Runnable(Generic[Input, Output], ABC):
def get_prompts(
self, config: Optional[RunnableConfig] = None
) -> List[BasePromptTemplate]:
) -> list[BasePromptTemplate]:
"""Return a list of prompts used by this Runnable."""
from langchain_core.prompts.base import BasePromptTemplate
@ -614,7 +612,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
return RunnableSequence(self, *others, name=name)
def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]:
def pick(self, keys: Union[str, list[str]]) -> RunnableSerializable[Any, Any]:
"""Pick keys from the dict output of this Runnable.
Pick single key:
@ -670,11 +668,11 @@ class Runnable(Generic[Input, Output], ABC):
def assign(
self,
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
Runnable[dict[str, Any], Any],
Callable[[dict[str, Any]], Any],
Mapping[
str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
],
],
) -> RunnableSerializable[Any, Any]:
@ -710,7 +708,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
""" --- Public API --- """
@ -744,12 +742,12 @@ class Runnable(Generic[Input, Output], ABC):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
"""Default implementation runs invoke in parallel using a thread pool executor.
The default implementation of batch works well for IO bound runnables.
@ -786,7 +784,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]: ...
) -> Iterator[tuple[int, Output]]: ...
@overload
def batch_as_completed(
@ -796,7 +794,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ...
) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
def batch_as_completed(
self,
@ -805,7 +803,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
) -> Iterator[tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs,
yielding results as they complete."""
@ -816,7 +814,7 @@ class Runnable(Generic[Input, Output], ABC):
def invoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
) -> tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
@ -848,12 +846,12 @@ class Runnable(Generic[Input, Output], ABC):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
"""Default implementation runs ainvoke in parallel using asyncio.gather.
The default implementation of batch works well for IO bound runnables.
@ -902,7 +900,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]: ...
) -> AsyncIterator[tuple[int, Output]]: ...
@overload
def abatch_as_completed(
@ -912,7 +910,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ...
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
async def abatch_as_completed(
self,
@ -921,7 +919,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
"""Run ainvoke in parallel on a list of inputs,
yielding results as they complete.
@ -947,7 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
async def ainvoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
) -> tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = await self.ainvoke(
@ -1699,8 +1697,8 @@ class Runnable(Generic[Input, Output], ABC):
def with_types(
self,
*,
input_type: Optional[Type[Input]] = None,
output_type: Optional[Type[Output]] = None,
input_type: Optional[type[Input]] = None,
output_type: Optional[type[Output]] = None,
) -> Runnable[Input, Output]:
"""
Bind input and output types to a Runnable, returning a new Runnable.
@ -1722,7 +1720,7 @@ class Runnable(Generic[Input, Output], ABC):
def with_retry(
self,
*,
retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,),
retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,),
wait_exponential_jitter: bool = True,
stop_after_attempt: int = 3,
) -> Runnable[Input, Output]:
@ -1789,7 +1787,7 @@ class Runnable(Generic[Input, Output], ABC):
max_attempt_number=stop_after_attempt,
)
def map(self) -> Runnable[List[Input], List[Output]]:
def map(self) -> Runnable[list[Input], list[Output]]:
"""
Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
@ -1815,7 +1813,7 @@ class Runnable(Generic[Input, Output], ABC):
self,
fallbacks: Sequence[Runnable[Input, Output]],
*,
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,),
exception_key: Optional[str] = None,
) -> RunnableWithFallbacksT[Input, Output]:
"""Add fallbacks to a Runnable, returning a new Runnable.
@ -1893,7 +1891,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
serialized: Optional[Dict[str, Any]] = None,
serialized: Optional[dict[str, Any]] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
@ -1942,7 +1940,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
serialized: Optional[Dict[str, Any]] = None,
serialized: Optional[dict[str, Any]] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
@ -1977,23 +1975,23 @@ class Runnable(Generic[Input, Output], ABC):
def _batch_with_config(
self,
func: Union[
Callable[[List[Input]], List[Union[Exception, Output]]],
Callable[[list[Input]], list[Union[Exception, Output]]],
Callable[
[List[Input], List[CallbackManagerForChainRun]],
List[Union[Exception, Output]],
[list[Input], list[CallbackManagerForChainRun]],
list[Union[Exception, Output]],
],
Callable[
[List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
List[Union[Exception, Output]],
[list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]],
list[Union[Exception, Output]],
],
],
input: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
input: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement batch() in subclasses."""
if not input:
@ -2045,27 +2043,27 @@ class Runnable(Generic[Input, Output], ABC):
async def _abatch_with_config(
self,
func: Union[
Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
Callable[[list[Input]], Awaitable[list[Union[Exception, Output]]]],
Callable[
[List[Input], List[AsyncCallbackManagerForChainRun]],
Awaitable[List[Union[Exception, Output]]],
[list[Input], list[AsyncCallbackManagerForChainRun]],
Awaitable[list[Union[Exception, Output]]],
],
Callable[
[
List[Input],
List[AsyncCallbackManagerForChainRun],
List[RunnableConfig],
list[Input],
list[AsyncCallbackManagerForChainRun],
list[RunnableConfig],
],
Awaitable[List[Union[Exception, Output]]],
Awaitable[list[Union[Exception, Output]]],
],
],
input: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
input: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement abatch() in subclasses."""
if not input:
@ -2073,7 +2071,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = get_config_list(config, len(input))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
callback_manager.on_chain_start(
None,
@ -2106,7 +2104,7 @@ class Runnable(Generic[Input, Output], ABC):
raise
else:
first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = []
coros: list[Awaitable[None]] = []
for run_manager, out in zip(run_managers, output):
if isinstance(out, Exception):
first_exception = first_exception or out
@ -2333,11 +2331,11 @@ class Runnable(Generic[Input, Output], ABC):
@beta_decorator.beta(message="This API is in beta and may change in the future.")
def as_tool(
self,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[type[BaseModel]] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None,
arg_types: Optional[dict[str, type]] = None,
) -> BaseTool:
"""Create a BaseTool from a Runnable.
@ -2573,8 +2571,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def _seq_input_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]:
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
first = steps[0]
@ -2599,8 +2597,8 @@ def _seq_input_schema(
def _seq_output_schema(
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> Type[BaseModel]:
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
) -> type[BaseModel]:
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
last = steps[-1]
@ -2730,7 +2728,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# the last type.
first: Runnable[Input, Any]
"""The first Runnable in the sequence."""
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
middle: list[Runnable[Any, Any]] = Field(default_factory=list)
"""The middle Runnables in the sequence."""
last: Runnable[Any, Output]
"""The last Runnable in the sequence."""
@ -2740,7 +2738,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
*steps: RunnableLike,
name: Optional[str] = None,
first: Optional[Runnable[Any, Any]] = None,
middle: Optional[List[Runnable[Any, Any]]] = None,
middle: Optional[list[Runnable[Any, Any]]] = None,
last: Optional[Runnable[Any, Any]] = None,
) -> None:
"""Create a new RunnableSequence.
@ -2755,7 +2753,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Raises:
ValueError: If the sequence has less than 2 steps.
"""
steps_flat: List[Runnable] = []
steps_flat: list[Runnable] = []
if not steps:
if first is not None and last is not None:
steps_flat = [first] + (middle or []) + [last]
@ -2776,12 +2774,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def steps(self) -> List[Runnable[Any, Any]]:
def steps(self) -> list[Runnable[Any, Any]]:
"""All the Runnables that make up the sequence in order.
Returns:
@ -2804,18 +2802,18 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
)
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
"""The type of the input to the Runnable."""
return self.first.InputType
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
"""The type of the output of the Runnable."""
return self.last.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the input schema of the Runnable.
Args:
@ -2828,7 +2826,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the output schema of the Runnable.
Args:
@ -2840,7 +2838,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
return _seq_output_schema(self.steps, config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the config specs of the Runnable.
Returns:
@ -2862,8 +2860,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
lambda x: x[1],
)
next_deps: Set[str] = set()
deps_by_pos: Dict[int, Set[str]] = {}
next_deps: set[str] = set()
deps_by_pos: dict[int, set[str]] = {}
for pos, specs in specs_by_pos:
deps_by_pos[pos] = next_deps
next_deps = next_deps | {spec[0].id for spec in specs}
@ -3065,12 +3063,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.beta.runnables.context import config_with_context
from langchain_core.callbacks.manager import CallbackManager
@ -3111,7 +3109,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
failed_inputs_map: dict[int, Exception] = {}
for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
@ -3191,12 +3189,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
from langchain_core.callbacks.manager import AsyncCallbackManager
@ -3221,7 +3219,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
None,
@ -3240,7 +3238,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
failed_inputs_map: dict[int, Exception] = {}
for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
@ -3305,7 +3303,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
raise
else:
first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = []
coros: list[Awaitable[None]] = []
for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception):
first_exception = first_exception or out
@ -3533,7 +3531,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -3567,7 +3565,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the input schema of the Runnable.
Args:
@ -3596,7 +3594,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get the output schema of the Runnable.
Args:
@ -3609,7 +3607,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return create_model_v2(self.get_name("Output"), field_definitions=fields)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the config specs of the Runnable.
Returns:
@ -3662,7 +3660,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
from langchain_core.callbacks.manager import CallbackManager
# setup callbacks
@ -3724,7 +3722,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Dict[str, Any]:
) -> dict[str, Any]:
# setup callbacks
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
@ -3829,7 +3827,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
@ -3839,7 +3837,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
yield from self.transform(iter([input]), config)
async def _atransform(
@ -3898,7 +3896,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
@ -3909,7 +3907,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
@ -4064,7 +4062,7 @@ class RunnableGenerator(Runnable[Input, Output]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to provide the
# module of the underlying function when creating the model.
@ -4100,7 +4098,7 @@ class RunnableGenerator(Runnable[Input, Output]):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to provide the
# module of the underlying function when creating the model.
@ -4346,7 +4344,7 @@ class RunnableLambda(Runnable[Input, Output]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""The pydantic schema for the input to this Runnable.
Args:
@ -4414,7 +4412,7 @@ class RunnableLambda(Runnable[Input, Output]):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# Override the default implementation.
# For a runnable lambda, we need to provide the
# module of the underlying function when creating the model.
@ -4435,7 +4433,7 @@ class RunnableLambda(Runnable[Input, Output]):
)
@property
def deps(self) -> List[Runnable]:
def deps(self) -> list[Runnable]:
"""The dependencies of this Runnable.
Returns:
@ -4449,7 +4447,7 @@ class RunnableLambda(Runnable[Input, Output]):
else:
objects = []
deps: List[Runnable] = []
deps: list[Runnable] = []
for obj in objects:
if isinstance(obj, Runnable):
deps.append(obj)
@ -4458,7 +4456,7 @@ class RunnableLambda(Runnable[Input, Output]):
return deps
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for dep in self.deps for spec in dep.config_specs
)
@ -4944,7 +4942,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model_v2(
self.get_name("Input"),
root=(
@ -4962,12 +4960,12 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
)
@property
def OutputType(self) -> Type[List[Output]]:
def OutputType(self) -> type[list[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
schema = self.bound.get_output_schema(config)
return create_model_v2(
self.get_name("Output"),
@ -4983,7 +4981,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
@ -4994,42 +4992,42 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _invoke(
self,
inputs: List[Input],
inputs: list[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
]
return self.bound.batch(inputs, configs, **kwargs)
def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]:
self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Output]:
return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke(
self,
inputs: List[Input],
inputs: list[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
configs = [
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
]
return await self.bound.abatch(inputs, configs, **kwargs)
async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]:
self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> list[Output]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
async def astream_events(
@ -5074,7 +5072,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -5181,7 +5179,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
config: RunnableConfig = Field(default_factory=dict)
"""The config to bind to the underlying Runnable."""
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field(
default_factory=list
)
"""The config factories to bind to the underlying Runnable."""
@ -5210,10 +5208,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
config_factories: Optional[
List[Callable[[RunnableConfig], RunnableConfig]]
list[Callable[[RunnableConfig], RunnableConfig]]
] = None,
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
custom_input_type: Optional[Union[type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[type[Output], BaseModel]] = None,
**other_kwargs: Any,
) -> None:
"""Create a RunnableBinding from a Runnable and kwargs.
@ -5255,7 +5253,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return self.bound.get_name(suffix, name=name)
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
return (
cast(Type[Input], self.custom_input_type)
if self.custom_input_type is not None
@ -5263,7 +5261,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
)
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
return (
cast(Type[Output], self.custom_output_type)
if self.custom_output_type is not None
@ -5272,20 +5270,20 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
if self.custom_input_type is not None:
return super().get_input_schema(config)
return self.bound.get_input_schema(merge_configs(self.config, config))
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
if self.custom_output_type is not None:
return super().get_output_schema(config)
return self.bound.get_output_schema(merge_configs(self.config, config))
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.bound.config_specs
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
@ -5296,7 +5294,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -5330,12 +5328,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
@ -5352,12 +5350,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
@ -5380,7 +5378,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]: ...
) -> Iterator[tuple[int, Output]]: ...
@overload
def batch_as_completed(
@ -5390,7 +5388,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ...
) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
def batch_as_completed(
self,
@ -5399,7 +5397,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
) -> Iterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
@ -5431,7 +5429,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]: ...
) -> AsyncIterator[tuple[int, Output]]: ...
@overload
def abatch_as_completed(
@ -5441,7 +5439,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ...
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
async def abatch_as_completed(
self,
@ -5450,7 +5448,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
@ -5590,7 +5588,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -5678,8 +5676,8 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None,
input_type: Optional[Union[type[Input], BaseModel]] = None,
output_type: Optional[Union[type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,

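For callers of the Runnable API nothing moves; a sketch (invented example) of `batch` and `with_retry` under the new annotations:

    from langchain_core.runnables import RunnableLambda

    double = RunnableLambda(lambda x: x * 2)

    # batch(inputs: list[Input]) -> list[Output] after this change
    outputs: list[int] = double.batch([1, 2, 3])
    assert outputs == [2, 4, 6]

    # retry_if_exception_type is now tuple[type[BaseException], ...]
    safe_double = double.with_retry(retry_if_exception_type=(ValueError,))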
View File

@ -12,7 +12,6 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
Iterator,
@ -56,13 +55,13 @@ class EmptyDict(TypedDict, total=False):
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
tags: list[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
metadata: dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
@ -90,7 +89,7 @@ class RunnableConfig(TypedDict, total=False):
Maximum number of times a call can recurse. If not provided, defaults to 25.
"""
configurable: Dict[str, Any]
configurable: dict[str, Any]
"""
Runtime values for attributes previously made configurable on this Runnable,
or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
@ -205,7 +204,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
def get_config_list(
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
) -> list[RunnableConfig]:
"""Get a list of configs from a single config or a list of configs.
It is useful for subclasses overriding batch() or abatch().
@ -255,7 +254,7 @@ def patch_config(
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
configurable: Optional[Dict[str, Any]] = None,
configurable: Optional[dict[str, Any]] = None,
) -> RunnableConfig:
"""Patch a config with new values.

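Since `RunnableConfig` is a `TypedDict`, callers pass plain dicts; a sketch (illustrative) matching the retyped fields above:

    from langchain_core.runnables import RunnableConfig, RunnableLambda

    config: RunnableConfig = {
        "tags": ["example"],               # tags: list[str]
        "metadata": {"attempt": 1},        # metadata: dict[str, Any]
        "configurable": {"session": "a"},  # configurable: dict[str, Any]
    }
    RunnableLambda(lambda x: x).invoke("hi", config=config)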
View File

@ -8,12 +8,10 @@ from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
@ -69,27 +67,27 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
return self.default.InputType
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
return self.default.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
runnable, config = self.prepare(config)
return runnable.get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
runnable, config = self.prepare(config)
return runnable.get_output_schema(config)
@ -109,7 +107,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
) -> tuple[Runnable[Input, Output], RunnableConfig]:
"""Prepare the Runnable for invocation.
Args:
@ -127,7 +125,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@abstractmethod
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ...
) -> tuple[Runnable[Input, Output], RunnableConfig]: ...
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@ -143,12 +141,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self.prepare(c) for c in configs]
@ -164,7 +162,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return []
def invoke(
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input,
) -> Union[Output, Exception]:
bound, config = prepared
@ -185,12 +183,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
configs = get_config_list(config, len(inputs))
prepared = [self.prepare(c) for c in configs]
@ -206,7 +204,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
return []
async def ainvoke(
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
prepared: tuple[Runnable[Input, Output], RunnableConfig],
input: Input,
) -> Union[Output, Exception]:
bound, config = prepared
@ -362,15 +360,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
)
"""
fields: Dict[str, AnyConfigurableField]
fields: dict[str, AnyConfigurableField]
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableConfigurableFields.
Returns:
@ -412,7 +410,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
) -> tuple[Runnable[Input, Output], RunnableConfig]:
config = ensure_config(config)
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = {
@ -467,7 +465,7 @@ _enums_for_spec: WeakValueDictionary[
Union[
ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField
],
Type[StrEnum],
type[StrEnum],
] = WeakValueDictionary()
_enums_for_spec_lock = threading.Lock()
@ -532,7 +530,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which: ConfigurableField
"""The ConfigurableField to use to choose between alternatives."""
alternatives: Dict[
alternatives: dict[
str,
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
]
@ -547,12 +545,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
with _enums_for_spec_lock:
if which_enum := _enums_for_spec.get(self.which):
pass
@ -612,7 +610,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
) -> tuple[Runnable[Input, Output], RunnableConfig]:
config = ensure_config(config)
which = config.get("configurable", {}).get(self.which.id, self.default_key)
# remap configurable keys for the chosen alternative

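One subtlety in this file: module-level annotations such as `_enums_for_spec: WeakValueDictionary[..., type[StrEnum]]` are never evaluated at runtime because these modules use `from __future__ import annotations` (PEP 563), so the new syntax is free even where runtime subscription of a class only arrived in newer Pythons. A self-contained sketch of the same pattern (names invented):

    from __future__ import annotations

    from weakref import WeakValueDictionary

    class Spec:
        """Stand-in for a configurable-field spec."""

    # Postponed evaluation: the subscripted annotation stays a string.
    _cache: WeakValueDictionary[str, Spec] = WeakValueDictionary()
    spec = Spec()
    _cache["a"] = spec
    assert _cache["a"] is spec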
View File

@ -8,14 +8,10 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypedDict,
Union,
overload,
@ -106,8 +102,8 @@ class Node(NamedTuple):
id: str
name: str
data: Union[Type[BaseModel], RunnableType]
metadata: Optional[Dict[str, Any]]
data: Union[type[BaseModel], RunnableType]
metadata: Optional[dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
"""Return a copy of the node with optional new id and name.
@ -179,7 +175,7 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
"""Convert the data of a node to a string.
Args:
@ -202,7 +198,7 @@ def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
def node_data_json(
node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]:
) -> dict[str, Union[str, dict[str, Any]]]:
"""Convert the data of a node to a JSON-serializable format.
Args:
@ -217,7 +213,7 @@ def node_data_json(
from langchain_core.runnables.base import Runnable, RunnableSerializable
if isinstance(node.data, RunnableSerializable):
json: Dict[str, Any] = {
json: dict[str, Any] = {
"type": "runnable",
"data": {
"id": node.data.lc_id(),
@ -265,10 +261,10 @@ class Graph:
edges: List of edges in the graph. Defaults to an empty list.
"""
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
nodes: dict[str, Node] = field(default_factory=dict)
edges: list[Edge] = field(default_factory=list)
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
def to_json(self, *, with_schemas: bool = False) -> dict[str, list[dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format.
Args:
@ -282,7 +278,7 @@ class Graph:
node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values())
}
edges: List[Dict[str, Any]] = []
edges: list[dict[str, Any]] = []
for edge in self.edges:
edge_dict = {
"source": stable_node_ids[edge.source],
@ -315,10 +311,10 @@ class Graph:
def add_node(
self,
data: Union[Type[BaseModel], RunnableType],
data: Union[type[BaseModel], RunnableType],
id: Optional[str] = None,
*,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> Node:
"""Add a node to the graph and return it.
@ -386,7 +382,7 @@ class Graph:
def extend(
self, graph: Graph, *, prefix: str = ""
) -> Tuple[Optional[Node], Optional[Node]]:
) -> tuple[Optional[Node], Optional[Node]]:
"""Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs.
@ -622,7 +618,7 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: List[Node] = []
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets:
found.append(node)
@ -635,7 +631,7 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination."""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: List[Node] = []
found: list[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources:
found.append(node)

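Every Runnable exposes the graph API retyped above; a sketch (illustrative, assuming the usual nodes/edges keys of this export) of the `dict[...]`-typed JSON payload:

    from typing import Any

    from langchain_core.runnables import RunnableLambda

    chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
    payload: dict[str, list[dict[str, Any]]] = chain.get_graph().to_json()
    assert {"nodes", "edges"} <= payload.keys()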
View File

@ -6,10 +6,8 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)
@ -238,7 +236,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_factory_config: Sequence[ConfigurableFieldSpec]
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -366,7 +364,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
self._history_chain = history_chain
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableWithMessageHistory."""
return get_unique_config_specs(
super().config_specs + list(self.history_factory_config)
@ -374,10 +372,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
from langchain_core.messages import BaseMessage
fields: Dict = {}
fields: dict = {}
if self.input_messages_key and self.history_messages_key:
fields[self.input_messages_key] = (
Union[str, BaseMessage, Sequence[BaseMessage]],
@ -398,13 +396,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
)
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
output_type = self._history_chain.OutputType
return output_type
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives
@ -430,15 +428,15 @@ class RunnableWithMessageHistory(RunnableBindingBase):
module_name=self.__class__.__module__,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return False
async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
return True
def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages
@ -481,7 +479,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
) -> list[BaseMessage]:
from langchain_core.messages import BaseMessage
# If dictionary, try to pluck the single key representing messages
@ -514,7 +512,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
f"Got {output_val}."
)
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = hist.messages.copy()
@ -527,8 +525,8 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return messages
async def _aenter_history(
self, input: Dict[str, Any], config: RunnableConfig
) -> List[BaseMessage]:
self, input: dict[str, Any], config: RunnableConfig
) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
messages = (await hist.aget_messages()).copy()
@ -621,7 +619,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return config
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> list[str]:
"""Get the parameter names of the callable."""
sig = inspect.signature(callable_)
return list(sig.parameters.keys())
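
A standalone sketch of the helper above, runnable as-is on Python 3.9+; the session-history callable here is a hypothetical example, not the real factory:

import inspect
from typing import Callable


def get_parameter_names(callable_: Callable) -> list[str]:
    """Return the parameter names of a callable, as _get_parameter_names does."""
    sig = inspect.signature(callable_)
    return list(sig.parameters.keys())


def get_session_history(session_id: str, user_id: str) -> None:
    ...


print(get_parameter_names(get_session_history))  # ['session_id', 'user_id']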

View File

@ -13,10 +13,8 @@ from typing import (
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
Union,
cast,
)
@ -144,7 +142,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
"""
input_type: Optional[Type[Other]] = None
input_type: Optional[type[Other]] = None
func: Optional[
Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
@ -180,7 +178,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
]
] = None,
*,
input_type: Optional[Type[Other]] = None,
input_type: Optional[type[Other]] = None,
**kwargs: Any,
) -> None:
if inspect.iscoroutinefunction(func):
@ -194,7 +192,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -210,11 +208,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
def assign(
cls,
**kwargs: Union[
Runnable[Dict[str, Any], Any],
Callable[[Dict[str, Any]], Any],
Runnable[dict[str, Any], Any],
Callable[[dict[str, Any]], Any],
Mapping[
str,
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
],
],
) -> RunnableAssign:
@ -228,7 +226,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
A Runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
return RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
def invoke(
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
@ -392,9 +390,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
# returns {'input': 5, 'add_step': {'added': 15}}
"""
mapper: RunnableParallel[Dict[str, Any]]
mapper: RunnableParallel[dict[str, Any]]
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None:
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
@classmethod
@ -402,7 +400,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -418,7 +416,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not issubclass(map_input_schema, RootModel):
# ie. it's a dict
@ -428,7 +426,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if not issubclass(map_input_schema, RootModel) and not issubclass(
@ -453,7 +451,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return self.mapper.config_specs
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
@ -470,11 +468,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def _invoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
@ -490,19 +488,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def invoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
@ -518,19 +516,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def ainvoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _transform(
self,
input: Iterator[Dict[str, Any]],
input: Iterator[dict[str, Any]],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
@ -572,21 +570,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def transform(
self,
input: Iterator[Dict[str, Any]],
input: Iterator[dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
async def _atransform(
self,
input: AsyncIterator[Dict[str, Any]],
input: AsyncIterator[dict[str, Any]],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
@ -622,10 +620,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
input: AsyncIterator[dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
@ -633,19 +631,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def stream(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[dict[str, Any]]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
@ -683,9 +681,9 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
print(output_data) # Output: {'name': 'John', 'age': 30}
"""
keys: Union[str, List[str]]
keys: Union[str, list[str]]
def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None:
def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None:
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
@classmethod
@ -693,7 +691,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -707,7 +705,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
)
return super().get_name(suffix, name=name)
def _pick(self, input: Dict[str, Any]) -> Any:
def _pick(self, input: dict[str, Any]) -> Any:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
@ -723,36 +721,36 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def _invoke(
self,
input: Dict[str, Any],
) -> Dict[str, Any]:
input: dict[str, Any],
) -> dict[str, Any]:
return self._pick(input)
def invoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return self._call_with_config(self._invoke, input, config, **kwargs)
async def _ainvoke(
self,
input: Dict[str, Any],
) -> Dict[str, Any]:
input: dict[str, Any],
) -> dict[str, Any]:
return self._pick(input)
async def ainvoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _transform(
self,
input: Iterator[Dict[str, Any]],
) -> Iterator[Dict[str, Any]]:
input: Iterator[dict[str, Any]],
) -> Iterator[dict[str, Any]]:
for chunk in input:
picked = self._pick(chunk)
if picked is not None:
@ -760,18 +758,18 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def transform(
self,
input: Iterator[Dict[str, Any]],
input: Iterator[dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
async def _atransform(
self,
input: AsyncIterator[Dict[str, Any]],
) -> AsyncIterator[Dict[str, Any]]:
input: AsyncIterator[dict[str, Any]],
) -> AsyncIterator[dict[str, Any]]:
async for chunk in input:
picked = self._pick(chunk)
if picked is not None:
@ -779,10 +777,10 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
input: AsyncIterator[dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
@ -790,19 +788,19 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def stream(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
) -> Iterator[dict[str, Any]]:
return self.transform(iter([input]), config, **kwargs)
async def astream(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
) -> AsyncIterator[dict[str, Any]]:
async def input_aiter() -> AsyncIterator[dict[str, Any]]:
yield input
async for chunk in self.atransform(input_aiter(), config, **kwargs):
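
A usage sketch of the dict-in/dict-out contract these annotations describe, assuming RunnablePassthrough and RunnablePick are exported from langchain_core.runnables as in this version of the package:

from langchain_core.runnables import RunnablePassthrough, RunnablePick

chain = (
    RunnablePassthrough.assign(total_chars=lambda d: len(d["input"]))
    | RunnablePick(["input", "total_chars"])
)
# assign() merges the computed key into the input dict; pick() keeps a subset.
print(chain.invoke({"input": "hello", "extra": object()}))
# {'input': 'hello', 'total_chars': 5}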

View File

@ -71,7 +71,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
runnables: Mapping[str, Runnable[Any, Output]]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.runnables.values() for spec in step.config_specs
)
@ -94,7 +94,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -125,12 +125,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
def batch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[RouterInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
if not inputs:
return []
@ -160,12 +160,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
async def abatch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[RouterInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
if not inputs:
return []
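
A brief sketch of the batch contract above, assuming the constructor shown in this file: RouterRunnable dispatches each input to the runnable named by its "key", so a list of router inputs yields a list of outputs.

from langchain_core.runnables import RouterRunnable, RunnableLambda

router = RouterRunnable(
    runnables={
        "double": RunnableLambda(lambda x: x * 2),
        "shout": RunnableLambda(lambda s: s.upper()),
    }
)
print(router.batch([
    {"key": "double", "input": 3},
    {"key": "shout", "input": "hi"},
]))  # [6, 'HI']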

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal, Sequence, Union
from typing import Any, Literal, Sequence, Union
from typing_extensions import NotRequired, TypedDict
@ -110,7 +110,7 @@ class BaseStreamEvent(TypedDict):
Each child Runnable that gets invoked as part of the execution of a parent Runnable
is assigned its own unique ID.
"""
tags: NotRequired[List[str]]
tags: NotRequired[list[str]]
"""Tags associated with the Runnable that generated this event.
Tags are always inherited from parent Runnables.
@ -118,7 +118,7 @@ class BaseStreamEvent(TypedDict):
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
"""
metadata: NotRequired[Dict[str, Any]]
metadata: NotRequired[dict[str, Any]]
"""Metadata associated with the Runnable that generated this event.
Metadata can either be bound to a Runnable using
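
The same pattern in miniature: NotRequired fields compose cleanly with builtin generics in a TypedDict. This StreamEvent is a simplified stand-in, not the real BaseStreamEvent:

from typing import Any

from typing_extensions import NotRequired, TypedDict


class StreamEvent(TypedDict):
    run_id: str
    tags: NotRequired[list[str]]            # may be absent
    metadata: NotRequired[dict[str, Any]]   # may be absent


event: StreamEvent = {"run_id": "abc", "tags": ["hello"]}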

View File

@ -18,13 +18,11 @@ from typing import (
Coroutine,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Protocol,
Sequence,
Set,
TypeVar,
Union,
)
@ -126,7 +124,7 @@ def asyncio_accepts_context() -> bool:
class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict."""
def __init__(self, name: str, keys: Set[str]) -> None:
def __init__(self, name: str, keys: set[str]) -> None:
"""Initialize the visitor.
Args:
@ -181,7 +179,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
"""Check if the first argument of a function is a dict."""
def __init__(self) -> None:
self.keys: Set[str] = set()
self.keys: set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
@ -230,8 +228,8 @@ class NonLocals(ast.NodeVisitor):
"""Get nonlocal variables accessed."""
def __init__(self) -> None:
self.loads: Set[str] = set()
self.stores: Set[str] = set()
self.loads: set[str] = set()
self.stores: set[str] = set()
def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node.
@ -271,7 +269,7 @@ class FunctionNonLocals(ast.NodeVisitor):
"""Get the nonlocal variables accessed of a function."""
def __init__(self) -> None:
self.nonlocals: Set[str] = set()
self.nonlocals: set[str] = set()
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
@ -335,7 +333,7 @@ class GetLambdaSource(ast.NodeVisitor):
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]:
"""Get the keys of the first argument of a function if it is a dict.
Args:
@ -378,7 +376,7 @@ def get_lambda_source(func: Callable) -> Optional[str]:
return name
def get_function_nonlocals(func: Callable) -> List[Any]:
def get_function_nonlocals(func: Callable) -> list[Any]:
"""Get the nonlocal variables accessed by a function.
Args:
@ -392,7 +390,7 @@ def get_function_nonlocals(func: Callable) -> List[Any]:
tree = ast.parse(textwrap.dedent(code))
visitor = FunctionNonLocals()
visitor.visit(tree)
values: List[Any] = []
values: list[Any] = []
closure = inspect.getclosurevars(func)
candidates = {**closure.globals, **closure.nonlocals}
for k, v in candidates.items():
@ -608,12 +606,12 @@ class ConfigurableFieldSpec(NamedTuple):
description: Optional[str] = None
default: Any = None
is_shared: bool = False
dependencies: Optional[List[str]] = None
dependencies: Optional[list[str]] = None
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
) -> list[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs.
Args:
@ -628,7 +626,7 @@ def get_unique_config_specs(
grouped = groupby(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
)
unique: List[ConfigurableFieldSpec] = []
unique: list[ConfigurableFieldSpec] = []
for id, dupes in grouped:
first = next(dupes)
others = list(dupes)
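
The dedup logic of get_unique_config_specs, sketched standalone: sort by id, group, keep the first of each group, and (as the real function does) reject conflicting duplicates. Spec here is a simplified placeholder for ConfigurableFieldSpec:

from itertools import groupby
from typing import Iterable, NamedTuple


class Spec(NamedTuple):
    id: str
    default: int = 0


def unique_specs(specs: Iterable[Spec]) -> list[Spec]:
    unique: list[Spec] = []
    for _, dupes in groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id):
        first = next(dupes)
        if any(d != first for d in dupes):
            raise ValueError(f"Conflicting specs for id {first.id!r}")
        unique.append(first)
    return unique


print(unique_specs([Spec("a"), Spec("b"), Spec("a")]))
# [Spec(id='a', default=0), Spec(id='b', default=0)]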

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union
from pydantic import BaseModel
@ -142,10 +142,10 @@ class Operation(FilterDirective):
"""
operator: Operator
arguments: List[FilterDirective]
arguments: list[FilterDirective]
def __init__(
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
self, operator: Operator, arguments: list[FilterDirective], **kwargs: Any
) -> None:
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]

View File

@ -13,12 +13,9 @@ from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -78,11 +75,11 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _is_annotated_type(typ: Type[Any]) -> bool:
def _is_annotated_type(typ: type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_annotation_description(arg_type: Type) -> str | None:
def _get_annotation_description(arg_type: type) -> str | None:
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
for annotation in annotated_args[1:]:
@ -92,7 +89,7 @@ def _get_annotation_description(arg_type: Type) -> str | None:
def _get_filtered_args(
inferred_model: Type[BaseModel],
inferred_model: type[BaseModel],
func: Callable,
*,
filter_args: Sequence[str],
@ -112,7 +109,7 @@ def _get_filtered_args(
def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
@ -141,7 +138,7 @@ def _infer_arg_descriptions(
*,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Infer argument descriptions from a function's docstring."""
if hasattr(inspect, "get_annotations"):
# This is for python < 3.10
@ -218,7 +215,7 @@ def create_schema_from_function(
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
@ -273,7 +270,7 @@ def create_schema_from_function(
filter_args_ = filter_args
else:
# Handle classmethods and instance methods
existing_params: List[str] = list(sig.parameters.keys())
existing_params: list[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
else:
@ -395,13 +392,13 @@ class ChildTool(BaseTool):
description="Callback manager to add to the run trace.",
)
)
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
"""Optional list of tags associated with the tool. Defaults to None.
These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None.
This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
@ -451,7 +448,7 @@ class ChildTool(BaseTool):
return self.get_input_schema().model_json_schema()["properties"]
@property
def tool_call_schema(self) -> Type[BaseModel]:
def tool_call_schema(self) -> type[BaseModel]:
full_schema = self.get_input_schema()
fields = []
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
@ -465,7 +462,7 @@ class ChildTool(BaseTool):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""The tool's input schema.
Args:
@ -481,7 +478,7 @@ class ChildTool(BaseTool):
def invoke(
self,
input: Union[str, Dict, ToolCall],
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -490,7 +487,7 @@ class ChildTool(BaseTool):
async def ainvoke(
self,
input: Union[str, Dict, ToolCall],
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -499,7 +496,7 @@ class ChildTool(BaseTool):
# --- Tool ---
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
"""Convert tool input to a pydantic model.
Args:
@ -536,7 +533,7 @@ class ChildTool(BaseTool):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.
Args:
@ -574,7 +571,7 @@ class ChildTool(BaseTool):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
tool_input = self._parse_input(tool_input)
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
@ -585,14 +582,14 @@ class ChildTool(BaseTool):
def run(
self,
tool_input: Union[str, Dict[str, Any]],
tool_input: Union[str, dict[str, Any]],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
@ -696,14 +693,14 @@ class ChildTool(BaseTool):
async def arun(
self,
tool_input: Union[str, Dict],
tool_input: Union[str, dict],
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
@ -866,7 +863,7 @@ def _prep_run_args(
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig],
**kwargs: Any,
) -> Tuple[Union[str, Dict], Dict]:
) -> tuple[Union[str, dict], dict]:
config = ensure_config(config)
if _is_tool_call(input):
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
@ -933,7 +930,7 @@ def _stringify(content: Any) -> str:
return str(content)
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
if isinstance(func, functools.partial):
func = func.func
try:
@ -956,7 +953,7 @@ class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
def _is_injected_arg_type(type_: Type) -> bool:
def _is_injected_arg_type(type_: type) -> bool:
return any(
isinstance(arg, InjectedToolArg)
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
@ -966,10 +963,10 @@ def _is_injected_arg_type(type_: Type) -> bool:
def _get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]:
) -> dict[str, type]:
# cls has no subscript: cls = FooBar
if isinstance(cls, type):
annotations: Dict[str, Type] = {}
annotations: dict[str, type] = {}
for name, param in inspect.signature(cls).parameters.items():
# Exclude hidden init args added by pydantic Config. For example if
# BaseModel(extra="allow") then "extra_data" will part of init sig.
@ -979,7 +976,7 @@ def _get_all_basemodel_annotations(
) and name not in fields:
continue
annotations[name] = param.annotation
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
# cls has subscript: cls = FooBar[int]
else:
annotations = _get_all_basemodel_annotations(
@ -1011,7 +1008,7 @@ def _get_all_basemodel_annotations(
# parent_origin = Baz,
# generic_type_vars = (type vars in Baz)
# generic_map = {type var in Baz: str}
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
generic_map = {
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
}
@ -1027,10 +1024,10 @@ def _get_all_basemodel_annotations(
def _replace_type_vars(
type_: Type,
generic_map: Optional[Dict[TypeVar, Type]] = None,
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
default_to_bound: bool = True,
) -> Type:
) -> type:
generic_map = generic_map or {}
if isinstance(type_, TypeVar):
if type_ in generic_map:
@ -1043,7 +1040,7 @@ def _replace_type_vars(
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
)
return _py_38_safe_origin(origin)[new_args]
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else:
return type_
@ -1052,5 +1049,5 @@ class BaseToolkit(BaseModel, ABC):
"""Base Toolkit representing a collection of related tools."""
@abstractmethod
def get_tools(self) -> List[BaseTool]:
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""

View File

@ -8,7 +8,7 @@ from langchain_core.tools.base import BaseTool
ToolsRenderer = Callable[[List[BaseTool]], str]
def render_text_description(tools: List[BaseTool]) -> str:
def render_text_description(tools: list[BaseTool]) -> str:
"""Render the tool name and description in plain text.
Args:
@ -36,7 +36,7 @@ def render_text_description(tools: List[BaseTool]) -> str:
return "\n".join(descriptions)
def render_text_description_and_args(tools: List[BaseTool]) -> str:
def render_text_description_and_args(tools: list[BaseTool]) -> str:
"""Render the tool name, description, and args in plain text.
Args:
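
Anything matching Callable[[list[BaseTool]], str] is a valid ToolsRenderer; a tiny sketch using the @tool decorator from this package (the formatting below is this example's own, roughly mirroring the built-in renderer's "name: description"):

from langchain_core.tools import BaseTool, tool


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


def render(tools: list[BaseTool]) -> str:
    return "\n".join(f"{t.name}: {t.description}" for t in tools)


print(render([add]))  # e.g. "add: Add two integers."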

View File

@ -5,10 +5,7 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
)
@ -40,7 +37,7 @@ class Tool(BaseTool):
async def ainvoke(
self,
input: Union[str, Dict, ToolCall],
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -65,7 +62,7 @@ class Tool(BaseTool):
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input
@ -131,7 +128,7 @@ class Tool(BaseTool):
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[type[BaseModel]] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func

View File

@ -6,11 +6,8 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Type,
Union,
)
@ -50,7 +47,7 @@ class StructuredTool(BaseTool):
# TODO: Is this needed?
async def ainvoke(
self,
input: Union[str, Dict, ToolCall],
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -112,7 +109,7 @@ class StructuredTool(BaseTool):
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[type[BaseModel]] = None,
infer_schema: bool = True,
*,
response_format: Literal["content", "content_and_artifact"] = "content",
@ -209,7 +206,7 @@ class StructuredTool(BaseTool):
)
def _filter_schema_args(func: Callable) -> List[str]:
def _filter_schema_args(func: Callable) -> list[str]:
filter_args = list(FILTERED_ARGS)
if config_param := _get_runnable_config_param(func):
filter_args.append(config_param)

View File

@ -8,8 +8,6 @@ from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Union,
@ -52,13 +50,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -93,13 +91,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -238,13 +236,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
@ -282,10 +280,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chain_end(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""End a trace for a chain run.
@ -313,7 +311,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self,
error: BaseException,
*,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> Run:
@ -340,15 +338,15 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run.
@ -429,13 +427,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -556,13 +554,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Any:
@ -585,13 +583,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
llm_run = self._create_llm_run(
@ -642,7 +640,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
llm_run = self._complete_llm_run(
@ -658,7 +656,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
llm_run = self._errored_llm_run(
@ -670,13 +668,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
@ -697,10 +695,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chain_end(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
chain_run = self._complete_chain_run(
@ -716,7 +714,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
error: BaseException,
*,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> None:
@ -731,15 +729,15 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
tool_run = self._create_tool_run(
@ -776,7 +774,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
tool_run = self._errored_tool_run(
@ -788,13 +786,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
@ -819,7 +817,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
retrieval_run = self._errored_retrieval_run(
@ -839,7 +837,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
retrieval_run = self._complete_retrieval_run(
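
The bookkeeping these tracers share, sketched standalone: runs live in a dict keyed by the stringified run id and are popped when the run ends. Run here is a simplified placeholder, not the langsmith schema:

from dataclasses import dataclass, field
from typing import Any
from uuid import UUID, uuid4


@dataclass
class Run:
    id: UUID
    inputs: dict[str, Any]
    outputs: dict[str, Any] = field(default_factory=dict)


run_map: dict[str, Run] = {}

run_id = uuid4()
run_map[str(run_id)] = Run(id=run_id, inputs={"query": "hi"})  # on_*_start
run = run_map.pop(str(run_id))                                 # on_*_end
run.outputs = {"answer": "hello"}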

View File

@ -6,10 +6,7 @@ from typing import (
TYPE_CHECKING,
Any,
Generator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
@ -53,7 +50,7 @@ def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
client: Optional[LangSmithClient] = None,
) -> Generator[LangChainTracer, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
@ -169,11 +166,11 @@ def _get_tracer_project() -> str:
)
_configure_hooks: List[
Tuple[
_configure_hooks: list[
tuple[
ContextVar[Optional[BaseCallbackHandler]],
bool,
Optional[Type[BaseCallbackHandler]],
Optional[type[BaseCallbackHandler]],
Optional[str],
]
] = []
@ -182,7 +179,7 @@ _configure_hooks: List[
def register_configure_hook(
context_var: ContextVar[Optional[Any]],
inheritable: bool,
handle_class: Optional[Type[BaseCallbackHandler]] = None,
handle_class: Optional[type[BaseCallbackHandler]] = None,
env_var: Optional[str] = None,
) -> None:
"""Register a configure hook.

View File

@ -11,13 +11,9 @@ from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
@ -81,9 +77,9 @@ class _TracerCore(ABC):
"""
super().__init__(**kwargs)
self._schema_format = _schema_format # For internal use only API will change.
self.run_map: Dict[str, Run] = {}
self.run_map: dict[str, Run] = {}
"""Map of run ID to run. Cleared on run end."""
self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
self.order_map: dict[UUID, tuple[UUID, str]] = {}
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
@abstractmethod
@ -137,7 +133,7 @@ class _TracerCore(ABC):
self.run_map[str(run.id)] = run
def _get_run(
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
self, run_id: UUID, run_type: Union[str, set[str], None] = None
) -> Run:
try:
run = self.run_map[str(run_id)]
@ -145,7 +141,7 @@ class _TracerCore(ABC):
raise TracerException(f"No indexed run ID {run_id}.") from exc
if isinstance(run_type, str):
run_types: Union[Set[str], None] = {run_type}
run_types: Union[set[str], None] = {run_type}
else:
run_types = run_type
if run_types is not None and run.run_type not in run_types:
@ -157,12 +153,12 @@ class _TracerCore(ABC):
def _create_chat_model_run(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -200,12 +196,12 @@ class _TracerCore(ABC):
def _create_llm_run(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -239,7 +235,7 @@ class _TracerCore(ABC):
Append token event to LLM run and return the run.
"""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token}
event_kwargs: dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
@ -258,7 +254,7 @@ class _TracerCore(ABC):
**kwargs: Any,
) -> Run:
llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = {
retry_d: dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
@ -306,12 +302,12 @@ class _TracerCore(ABC):
def _create_chain_run(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
@ -358,9 +354,9 @@ class _TracerCore(ABC):
def _complete_chain_run(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Update a chain run with outputs and end time."""
@ -375,7 +371,7 @@ class _TracerCore(ABC):
def _errored_chain_run(
self,
error: BaseException,
inputs: Optional[Dict[str, Any]],
inputs: Optional[dict[str, Any]],
run_id: UUID,
**kwargs: Any,
) -> Run:
@ -389,14 +385,14 @@ class _TracerCore(ABC):
def _create_tool_run(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Create a tool run."""
@ -428,7 +424,7 @@ class _TracerCore(ABC):
def _complete_tool_run(
self,
output: Dict[str, Any],
output: dict[str, Any],
run_id: UUID,
**kwargs: Any,
) -> Run:
@ -454,12 +450,12 @@ class _TracerCore(ABC):
def _create_retrieval_run(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:

View File

@ -6,7 +6,7 @@ import logging
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, List, Optional, Sequence, Union, cast
from uuid import UUID
import langsmith
@ -96,7 +96,7 @@ class EvaluatorCallbackHandler(BaseTracer):
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}
self.lock = threading.Lock()
global _TRACERS
_TRACERS.add(self)
@ -152,7 +152,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _select_eval_results(
self,
results: Union[EvaluationResult, EvaluationResults],
) -> List[EvaluationResult]:
) -> list[EvaluationResult]:
if isinstance(results, EvaluationResult):
results_ = [results]
elif isinstance(results, dict) and "results" in results:
@ -169,10 +169,10 @@ class EvaluatorCallbackHandler(BaseTracer):
evaluator_response: Union[EvaluationResult, EvaluationResults],
run: Run,
source_run_id: Optional[UUID] = None,
) -> List[EvaluationResult]:
) -> list[EvaluationResult]:
results = self._select_eval_results(evaluator_response)
for res in results:
source_info_: Dict[str, Any] = {}
source_info_: dict[str, Any] = {}
if res.evaluator_info:
source_info_ = {**res.evaluator_info, **source_info_}
run_id_ = getattr(res, "target_run_id", None)

View File

@ -8,7 +8,6 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
@ -66,14 +65,14 @@ class RunInfo(TypedDict):
"""
name: str
tags: List[str]
metadata: Dict[str, Any]
tags: list[str]
metadata: dict[str, Any]
run_type: str
inputs: NotRequired[Any]
parent_run_id: Optional[UUID]
def _assign_name(name: Optional[str], serialized: Optional[Dict[str, Any]]) -> str:
def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str:
"""Assign a name to a run."""
if name is not None:
return name
@ -107,15 +106,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
# Map of run ID to run info.
# the entry corresponding to a given run id is cleaned
# up when each corresponding run ends.
self.run_map: Dict[UUID, RunInfo] = {}
self.run_map: dict[UUID, RunInfo] = {}
# The callback event that corresponds to the end of a parent run
# may be invoked BEFORE the callback event that corresponds to the end
# of a child run, which results in clean up of run_map.
# So we keep track of the mapping between children and parent run IDs
# in a separate container. This container is GCed when the tracer is GCed.
self.parent_map: Dict[UUID, Optional[UUID]] = {}
self.parent_map: dict[UUID, Optional[UUID]] = {}
self.is_tapped: Dict[UUID, Any] = {}
self.is_tapped: dict[UUID, Any] = {}
# Filter which events will be sent over the queue.
self.root_event_filter = _RootEventFilter(
@ -132,7 +131,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream()
def _get_parent_ids(self, run_id: UUID) -> List[str]:
def _get_parent_ids(self, run_id: UUID) -> list[str]:
"""Get the parent IDs of a run (non-recursively) cast to strings."""
parent_ids = []
@ -269,8 +268,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self,
run_id: UUID,
*,
tags: Optional[List[str]],
metadata: Optional[Dict[str, Any]],
tags: Optional[list[str]],
metadata: Optional[dict[str, Any]],
parent_run_id: Optional[UUID],
name_: str,
run_type: str,
@ -296,13 +295,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
@ -337,13 +336,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
@ -384,8 +383,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
data: Any,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Generate a custom astream event."""
@ -456,7 +455,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"]
generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]]
generations: Union[list[list[GenerationChunk]], list[list[ChatGenerationChunk]]]
output: Union[dict, BaseMessage] = {}
if run_info["run_type"] == "chat_model":
@ -504,13 +503,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
@ -552,10 +551,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chain_end(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""End a trace for a chain run."""
@ -586,15 +585,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
@ -653,13 +652,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:

View File

@ -6,7 +6,7 @@ import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import UUID
from langsmith import Client
@ -91,7 +91,7 @@ class LangChainTracer(BaseTracer):
example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None,
client: Optional[Client] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer.
@ -114,13 +114,13 @@ class LangChainTracer(BaseTracer):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
@ -194,7 +194,7 @@ class LangChainTracer(BaseTracer):
)
raise ValueError("Failed to get run URL.")
def _get_tags(self, run: Run) -> List[str]:
def _get_tags(self, run: Run) -> list[str]:
"""Get combined tags for a run."""
tags = set(run.tags or [])
tags.update(self.tags or [])
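
_get_tags in miniature: run tags and tracer-level tags are unioned through a set, then handed back as a list (order not guaranteed):

from typing import Optional


def merge_tags(
    run_tags: Optional[list[str]], tracer_tags: Optional[list[str]]
) -> list[str]:
    tags = set(run_tags or [])
    tags.update(tracer_tags or [])
    return list(tags)


print(sorted(merge_tags(["a"], ["a", "b"])))  # ['a', 'b']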

View File

@ -7,9 +7,7 @@ from collections import defaultdict
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
@ -42,16 +40,16 @@ class LogEntry(TypedDict):
"""Name of the object being run."""
type: str
"""Type of the object being run, eg. prompt, chain, llm, etc."""
tags: List[str]
tags: list[str]
"""List of tags for the run."""
metadata: Dict[str, Any]
metadata: dict[str, Any]
"""Key-value pairs of metadata for the run."""
start_time: str
"""ISO-8601 timestamp of when the run started."""
streamed_output_str: List[str]
streamed_output_str: list[str]
"""List of LLM tokens streamed by this run, if applicable."""
streamed_output: List[Any]
streamed_output: list[Any]
"""List of output chunks streamed by this run, if available."""
inputs: NotRequired[Optional[Any]]
"""Inputs to this run. Not available currently via astream_log."""
@ -69,7 +67,7 @@ class RunState(TypedDict):
id: str
"""ID of the run."""
streamed_output: List[Any]
streamed_output: list[Any]
"""List of output chunks streamed by Runnable.stream()"""
final_output: Optional[Any]
"""Final output of the run, usually the result of aggregating (`+`) streamed_output.
@ -83,7 +81,7 @@ class RunState(TypedDict):
# Do we want tags/metadata on the root run? Client kinda knows it in most situations
# tags: List[str]
logs: Dict[str, LogEntry]
logs: dict[str, LogEntry]
"""Map of run names to sub-runs. If filters were supplied, this list will
contain only the runs that matched the filters."""
@ -91,14 +89,14 @@ class RunState(TypedDict):
class RunLogPatch:
"""Patch to the run log."""
ops: List[Dict[str, Any]]
ops: list[dict[str, Any]]
"""List of jsonpatch operations, which describe how to create the run state
from an empty dict. This is the minimal representation of the log, designed to
be serialized as JSON and sent over the wire to reconstruct the log on the other
side. Reconstruction of the state can be done with any jsonpatch-compliant library,
see https://jsonpatch.com for more information."""
def __init__(self, *ops: Dict[str, Any]) -> None:
def __init__(self, *ops: dict[str, Any]) -> None:
self.ops = list(ops)
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
@ -127,7 +125,7 @@ class RunLog(RunLogPatch):
state: RunState
"""Current state of the log, obtained from applying all ops in sequence."""
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
def __init__(self, *ops: dict[str, Any], state: RunState) -> None:
super().__init__(*ops)
self.state = state
@ -219,14 +217,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
self.lock = threading.Lock()
self.send_stream = memory_stream.get_send_stream()
self.receive_stream = memory_stream.get_receive_stream()
self._key_map_by_run_id: Dict[UUID, str] = {}
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
self._key_map_by_run_id: dict[UUID, str] = {}
self._counter_map_by_name: dict[str, int] = defaultdict(int)
self.root_id: Optional[UUID] = None
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__()
def send(self, *ops: Dict[str, Any]) -> bool:
def send(self, *ops: dict[str, Any]) -> bool:
"""Send a patch to the stream, return False if the stream is closed.
Args:
@ -477,7 +475,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
def _get_standardized_inputs(
run: Run, schema_format: Literal["original", "streaming_events"]
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
"""Extract standardized inputs from a run.
Standardizes the inputs based on the type of the runnable used.
@ -631,7 +629,7 @@ async def _astream_log_implementation(
except TypeError:
prev_final_output = None
final_output = chunk
patches: List[Dict[str, Any]] = []
patches: list[dict[str, Any]] = []
if with_streamed_output_list:
patches.append(
{
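
As the RunLogPatch docstring above notes, `ops` are plain jsonpatch operations, so the run state can be rebuilt with any jsonpatch-compliant library. A minimal sketch, assuming the third-party `jsonpatch` package and a hypothetical `patches` iterable received over the wire:

import jsonpatch  # third-party: pip install jsonpatch

state: dict = {}  # reconstruction starts from an empty dict, per the docstring
for patch in patches:  # `patches`: hypothetical iterable of RunLogPatch objects
    state = jsonpatch.apply_patch(state, patch.ops)

print(state["streamed_output"])  # aggregated output chunks, per RunState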

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import datetime
import warnings
from typing import Any, Dict, List, Optional, Type
from typing import Any, Optional
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
@ -18,7 +18,7 @@ from langchain_core._api import deprecated
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
def RunTypeEnum() -> Type[RunTypeEnumDep]:
def RunTypeEnum() -> type[RunTypeEnumDep]:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
@ -35,7 +35,7 @@ class TracerSessionV1Base(BaseModelV1):
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
extra: Optional[dict[str, Any]] = None
@deprecated("0.1.0", removal="1.0")
@ -72,10 +72,10 @@ class BaseRun(BaseModelV1):
parent_uuid: Optional[str] = None
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None
extra: Optional[dict[str, Any]] = None
execution_order: int
child_execution_order: int
serialized: Dict[str, Any]
serialized: dict[str, Any]
session_id: int
error: Optional[str] = None
@ -84,7 +84,7 @@ class BaseRun(BaseModelV1):
class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
prompts: list[str]
# Temporarily removed; LLMRun itself will be removed entirely later
# response: Optional[LLMResult] = None
@ -93,11 +93,11 @@ class LLMRun(BaseRun):
class ChainRun(BaseRun):
"""Class for ChainRun."""
inputs: Dict[str, Any]
outputs: Optional[Dict[str, Any]] = None
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list)
inputs: dict[str, Any]
outputs: Optional[dict[str, Any]] = None
child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
@deprecated("0.1.0", alternative="Run", removal="1.0")
@ -107,9 +107,9 @@ class ToolRun(BaseRun):
tool_input: str
output: Optional[str] = None
action: str
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list)
child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
# Begin V2 API Schemas
@ -126,9 +126,9 @@ class Run(BaseRunV2):
dotted_order: The dotted order.
"""
child_runs: List[Run] = FieldV1(default_factory=list)
tags: Optional[List[str]] = FieldV1(default_factory=list)
events: List[Dict[str, Any]] = FieldV1(default_factory=list)
child_runs: list[Run] = FieldV1(default_factory=list)
tags: Optional[list[str]] = FieldV1(default_factory=list)
events: list[dict[str, Any]] = FieldV1(default_factory=list)
trace_id: Optional[UUID] = None
dotted_order: Optional[str] = None

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Optional
def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]:
"""Merge many dicts, handling specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
@ -69,7 +69,7 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
return merged
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]:
"""Add many lists, handling None.
Args:
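
Behavior of the merge helpers above, sketched (module path assumed from the repository layout; the second assertion reflects the wider module's string-concatenation behavior for streamed chunks, not the hunk shown):

from langchain_core.utils._merge import merge_dicts

# Key present in both but None on the left: the right value wins,
# exactly the scenario the docstring calls out.
assert merge_dicts({"a": None, "b": "x"}, {"a": 1}) == {"a": 1, "b": "x"}

# String values are concatenated (how streamed message chunks accumulate).
assert merge_dicts({"b": "x"}, {"b": "y"}) == {"b": "xy"}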

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
def env_var_is_set(env_var: str) -> bool:
@ -22,8 +22,8 @@ def env_var_is_set(env_var: str) -> bool:
def get_from_dict_or_env(
data: Dict[str, Any],
key: Union[str, List[str]],
data: dict[str, Any],
key: Union[str, list[str]],
env_key: str,
default: Optional[str] = None,
) -> str:
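
A usage sketch for the signature above: the helper checks the data dict first (trying each key if a list is given), then falls back to the environment variable, then to `default`. The variable name below is hypothetical:

import os
from langchain_core.utils.env import get_from_dict_or_env

os.environ["MY_API_KEY"] = "from-env"  # hypothetical variable name

# Taken from `data` when present...
assert get_from_dict_or_env({"api_key": "from-dict"}, "api_key", "MY_API_KEY") == "from-dict"
# ...otherwise read from the environment (and finally from `default`).
assert get_from_dict_or_env({}, "api_key", "MY_API_KEY") == "from-env"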

View File

@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
removal="1.0",
)
def convert_pydantic_to_openai_function(
model: Type,
model: type,
*,
name: Optional[str] = None,
description: Optional[str] = None,
@ -106,8 +106,10 @@ def convert_pydantic_to_openai_function(
"""
if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2
else:
elif hasattr(model, "schema"):
schema = model.schema() # Pydantic 1
else:
raise TypeError("Model must be a Pydantic model.")
schema = dereference_refs(schema)
if "definitions" in schema: # pydantic 1
schema.pop("definitions", None)
@ -128,7 +130,7 @@ def convert_pydantic_to_openai_function(
removal="1.0",
)
def convert_pydantic_to_openai_tool(
model: Type[BaseModel],
model: type[BaseModel],
*,
name: Optional[str] = None,
description: Optional[str] = None,
@ -194,8 +196,8 @@ def convert_python_function_to_openai_function(
)
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
visited: Dict = {}
def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
visited: dict = {}
from pydantic.v1 import BaseModel
model = cast(
@ -209,11 +211,11 @@ _MAX_TYPED_DICT_RECURSION = 25
def _convert_any_typed_dicts_to_pydantic(
type_: Type,
type_: type,
*,
visited: Dict,
visited: dict,
depth: int = 0,
) -> Type:
) -> type:
from pydantic.v1 import Field as Field_v1
from pydantic.v1 import create_model as create_model_v1
@ -267,9 +269,9 @@ def _convert_any_typed_dicts_to_pydantic(
subscriptable_origin = _py_38_safe_origin(origin)
type_args = tuple(
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
for arg in type_args
for arg in type_args # type: ignore[index]
)
return subscriptable_origin[type_args]
return subscriptable_origin[type_args] # type: ignore[index]
else:
return type_
@ -333,10 +335,10 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
def convert_to_openai_function(
function: Union[Dict[str, Any], Type, Callable, BaseTool],
function: Union[dict[str, Any], type, Callable, BaseTool],
*,
strict: Optional[bool] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI function.
.. versionchanged:: 0.2.29
@ -411,10 +413,10 @@ def convert_to_openai_function(
def convert_to_openai_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
*,
strict: Optional[bool] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.
.. versionchanged:: 0.2.29
@ -441,13 +443,13 @@ def convert_to_openai_tool(
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
return tool
oai_function = convert_to_openai_function(tool, strict=strict)
oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function}
oai_tool: dict[str, Any] = {"type": "function", "function": oai_function}
return oai_tool
def tool_example_to_messages(
input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None
) -> List[BaseMessage]:
input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None
) -> list[BaseMessage]:
"""Convert an example into a list of messages that can be fed into an LLM.
This code is an adapter that converts a single example to a list of messages
@ -511,7 +513,7 @@ def tool_example_to_messages(
tool_example_to_messages(txt, [tool_call])
)
"""
messages: List[BaseMessage] = [HumanMessage(content=input)]
messages: list[BaseMessage] = [HumanMessage(content=input)]
openai_tool_calls = []
for tool_call in tool_calls:
openai_tool_calls.append(
@ -540,10 +542,10 @@ def tool_example_to_messages(
def _parse_google_docstring(
docstring: Optional[str],
args: List[str],
args: list[str],
*,
error_on_invalid_docstring: bool = False,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Assumes the function docstring follows Google Python style guide.
@ -590,12 +592,12 @@ def _parse_google_docstring(
return description, arg_descriptions
def _py_38_safe_origin(origin: Type) -> Type:
origin_union_type_map: Dict[Type, Any] = (
def _py_38_safe_origin(origin: type) -> type:
origin_union_type_map: dict[type, Any] = (
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
)
origin_map: Dict[Type, Any] = {
origin_map: dict[type, Any] = {
dict: Dict,
list: List,
tuple: Tuple,
@ -610,8 +612,8 @@ def _py_38_safe_origin(origin: Type) -> Type:
def _recursive_set_additional_properties_false(
schema: Dict[str, Any],
) -> Dict[str, Any]:
schema: dict[str, Any],
) -> dict[str, Any]:
if isinstance(schema, dict):
# Check if 'required' is a key at the current level or if the schema is empty,
# in which case additionalProperties still needs to be specified.
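
A sketch of the public entry point touched above, converting a Pydantic model into the OpenAI tool schema (the model and its field are illustrative):

from pydantic import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool

class GetWeather(BaseModel):
    """Get the current weather for a location."""
    location: str = Field(description="City name, e.g. 'Paris'")

tool = convert_to_openai_tool(GetWeather)
assert tool["type"] == "function"          # wrapped per the oai_tool dict above
assert tool["function"]["name"] == "GetWeather"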

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import json
import re
from typing import Any, Callable, List
from typing import Any, Callable
from langchain_core.exceptions import OutputParserException
@ -163,7 +163,7 @@ def _parse_json(
return parser(json_str)
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
contains the expected keys.
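
Usage sketch for the parser above; it raises OutputParserException (imported at the top of this file) when the JSON is invalid or a listed key is missing. Import path and example input are assumptions:

from langchain_core.utils.json import parse_and_check_json_markdown

text = """```json
{"action": "search", "action_input": "PEP 585"}
```"""
parsed = parse_and_check_json_markdown(text, expected_keys=["action", "action_input"])
assert parsed["action"] == "search"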

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Set
from typing import Any, Optional, Sequence
def _retrieve_ref(path: str, schema: dict) -> dict:
@ -24,9 +24,9 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
def _dereference_refs_helper(
obj: Any,
full_schema: Dict[str, Any],
full_schema: dict[str, Any],
skip_keys: Sequence[str],
processed_refs: Optional[Set[str]] = None,
processed_refs: Optional[set[str]] = None,
) -> Any:
if processed_refs is None:
processed_refs = set()
@ -63,8 +63,8 @@ def _dereference_refs_helper(
def _infer_skip_keys(
obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None
) -> List[str]:
obj: Any, full_schema: dict, processed_refs: Optional[set[str]] = None
) -> list[str]:
if processed_refs is None:
processed_refs = set()
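
The helpers above back the public dereference_refs used by the function-calling utilities. A sketch of inlining a `$ref` (the schema shape is an assumed minimal example):

from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"pet": {"$ref": "#/$defs/Pet"}},
    "$defs": {"Pet": {"type": "string"}},
}
resolved = dereference_refs(schema)
# The reference is inlined in place; the $defs section itself is skipped.
assert resolved["properties"]["pet"]["type"] == "string"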

View File

@ -16,7 +16,6 @@ from typing import (
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)
@ -45,7 +44,7 @@ class ChevronError(SyntaxError):
#
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def grab_literal(template: str, l_del: str) -> tuple[str, str]:
"""Parse a literal from the template.
Args:
@ -124,7 +123,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
return False
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]:
"""Parse a tag from a template.
Args:
@ -201,7 +200,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
def tokenize(
template: str, def_ldel: str = "{{", def_rdel: str = "}}"
) -> Iterator[Tuple[str, str]]:
) -> Iterator[tuple[str, str]]:
"""Tokenize a mustache template.
Tokenizes a mustache template in a generator fashion,
@ -427,13 +426,13 @@ def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
#
# The main rendering function
#
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
g_token_cache: dict[str, list[tuple[str, str]]] = {}
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
def render(
template: Union[str, List[Tuple[str, str]]] = "",
template: Union[str, list[tuple[str, str]]] = "",
data: Mapping[str, Any] = EMPTY_DICT,
partials_dict: Mapping[str, str] = EMPTY_DICT,
padding: str = "",
@ -490,7 +489,7 @@ def render(
if isinstance(template, Sequence) and not isinstance(template, str):
# Then we don't need to tokenize it
# But it does need to be a generator
tokens: Iterator[Tuple[str, str]] = (token for token in template)
tokens: Iterator[tuple[str, str]] = (token for token in template)
else:
if template in g_token_cache:
tokens = (token for token in g_token_cache[template])
@ -561,7 +560,7 @@ def render(
if callable(scope):
# Generate template text from tags
text = ""
tags: List[Tuple[str, str]] = []
tags: list[tuple[str, str]] = []
for token in tokens:
if token == ("end", key):
break
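
The mustache module's public surface is small; a render sketch using the signatures above:

from langchain_core.utils.mustache import render

out = render("Hello {{name}}!", data={"name": "world"})
assert out == "Hello world!"

# Token lists are cached in g_token_cache keyed by the template string,
# so repeated renders of the same template skip tokenize().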

View File

@ -11,7 +11,6 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
TypeVar,
@ -71,7 +70,7 @@ else:
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def is_pydantic_v1_subclass(cls: Type) -> bool:
def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
if PYDANTIC_MAJOR_VERSION == 1:
return True
@ -83,14 +82,14 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
return False
def is_pydantic_v2_subclass(cls: Type) -> bool:
def is_pydantic_v2_subclass(cls: type) -> bool:
"""Check if the installed Pydantic version is 1.x-like."""
from pydantic import BaseModel
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
def is_basemodel_subclass(cls: Type) -> bool:
def is_basemodel_subclass(cls: type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.
Check if the given class is a subclass of any of the following:
@ -166,7 +165,7 @@ def pre_init(func: Callable) -> Any:
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
"""Decorator to run a function before model initialization.
Args:
@ -218,12 +217,12 @@ class _IgnoreUnserializable(GenerateJsonSchema):
def _create_subset_model_v1(
name: str,
model: Type[BaseModel],
model: type[BaseModel],
field_names: list,
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import create_model
@ -256,12 +255,12 @@ def _create_subset_model_v1(
def _create_subset_model_v2(
name: str,
model: Type[pydantic.BaseModel],
field_names: List[str],
model: type[pydantic.BaseModel],
field_names: list[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[pydantic.BaseModel]:
) -> type[pydantic.BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model
from pydantic.fields import FieldInfo
@ -299,11 +298,11 @@ def _create_subset_model_v2(
def _create_subset_model(
name: str,
model: TypeBaseModel,
field_names: List[str],
field_names: list[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create subset model using the same pydantic version as the input model."""
if PYDANTIC_MAJOR_VERSION == 1:
return _create_subset_model_v1(
@ -344,25 +343,25 @@ if PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel as BaseModelV1
@overload
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
@overload
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
Type[BaseModelV2],
Type[BaseModelV1],
type[BaseModelV2],
type[BaseModelV1],
],
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields # type: ignore
@ -375,8 +374,8 @@ elif PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1_
def get_fields( # type: ignore[no-redef]
model: Union[Type[BaseModelV1_], BaseModelV1_],
) -> Dict[str, FieldInfoV1]:
model: Union[type[BaseModelV1_], BaseModelV1_],
) -> dict[str, FieldInfoV1]:
"""Get the field names of a Pydantic model."""
return model.__fields__ # type: ignore
else:
@ -394,14 +393,14 @@ def _create_root_model(
type_: Any,
module_name: Optional[str] = None,
default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create a base class."""
def schema(
cls: Type[BaseModel],
cls: type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> Dict[str, Any]:
) -> dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
@ -410,12 +409,12 @@ def _create_root_model(
return schema_
def model_json_schema(
cls: Type[BaseModel],
cls: type[BaseModel],
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> Dict[str, Any]:
) -> dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
@ -452,7 +451,7 @@ def _create_root_model_cached(
*,
module_name: Optional[str] = None,
default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
) -> type[BaseModel]:
return _create_root_model(
model_name, type_, default_=default_, module_name=module_name
)
@ -462,7 +461,7 @@ def _create_root_model_cached(
def _create_model_cached(
__model_name: str,
**field_definitions: Any,
) -> Type[BaseModel]:
) -> type[BaseModel]:
return _create_model_base(
__model_name,
__config__=_SchemaConfig,
@ -474,7 +473,7 @@ def create_model(
__model_name: str,
__module_name: Optional[str] = None,
**field_definitions: Any,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Please use create_model_v2 instead of this function.
@ -513,7 +512,7 @@ def create_model(
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
def _remap_field_definitions(field_definitions: Dict[str, Any]) -> Dict[str, Any]:
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
"""This remaps fields to avoid colliding with internal pydantic fields."""
from pydantic import Field
from pydantic.fields import FieldInfo
@ -547,9 +546,9 @@ def create_model_v2(
model_name: str,
*,
module_name: Optional[str] = None,
field_definitions: Optional[Dict[str, Any]] = None,
field_definitions: Optional[dict[str, Any]] = None,
root: Optional[Any] = None,
) -> Type[BaseModel]:
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Attention:
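
A sketch of create_model_v2 from the hunk above, using the (type, default) tuples that pydantic's create_model accepts (treat the exact field_definitions shape as an assumption):

from langchain_core.utils.pydantic import create_model_v2

Point = create_model_v2(
    "Point",
    field_definitions={"x": (int, ...), "y": (int, 0)},  # (type, default); Ellipsis = required
)
p = Point(x=1)
assert (p.x, p.y) == (1, 0)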

View File

@ -32,13 +32,9 @@ from typing import (
Callable,
ClassVar,
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
@ -66,13 +62,13 @@ class VectorStore(ABC):
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
# One of the kwargs should be `ids` which is a list of ids
# associated with the texts.
# This is not yet enforced in the type signature for backwards compatibility
# with existing implementations.
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
@ -124,7 +120,7 @@ class VectorStore(ABC):
)
return None
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
@ -138,7 +134,7 @@ class VectorStore(ABC):
raise NotImplementedError("delete method must be implemented by subclass.")
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Get documents by their IDs.
The returned documents are expected to have the ID field set to the ID of the
@ -167,7 +163,7 @@ class VectorStore(ABC):
)
# Implementations should override this method to provide an async native version.
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Async get documents by their IDs.
The returned documents are expected to have the ID field set to the ID of the
@ -194,7 +190,7 @@ class VectorStore(ABC):
return await run_in_executor(None, self.get_by_ids, ids)
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
self, ids: Optional[list[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Async delete by vector ID or other criteria.
@ -211,9 +207,9 @@ class VectorStore(ABC):
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Async run more texts through the embeddings and add to the vectorstore.
Args:
@ -254,7 +250,7 @@ class VectorStore(ABC):
return await self.aadd_documents(docs, **kwargs)
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
"""Add or update documents in the vectorstore.
Args:
@ -287,8 +283,8 @@ class VectorStore(ABC):
)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
self, documents: list[Document], **kwargs: Any
) -> list[str]:
"""Async run more documents through the embeddings and add to
the vectorstore.
@ -318,7 +314,7 @@ class VectorStore(ABC):
return await run_in_executor(None, self.add_documents, documents, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar to query using a specified search type.
Args:
@ -352,7 +348,7 @@ class VectorStore(ABC):
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
"""Async return docs most similar to query using a specified search type.
Args:
@ -386,7 +382,7 @@ class VectorStore(ABC):
@abstractmethod
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
"""Return docs most similar to query.
Args:
@ -442,7 +438,7 @@ class VectorStore(ABC):
def similarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Run similarity search with distance.
Args:
@ -456,7 +452,7 @@ class VectorStore(ABC):
async def asimilarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Async run similarity search with distance.
Args:
@ -479,7 +475,7 @@ class VectorStore(ABC):
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""
Default similarity search with relevance scores. Modify if necessary
in subclass.
@ -506,7 +502,7 @@ class VectorStore(ABC):
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""
Default similarity search with relevance scores. Modify if necessary
in subclass.
@ -533,7 +529,7 @@ class VectorStore(ABC):
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
@ -581,7 +577,7 @@ class VectorStore(ABC):
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
"""Async return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
@ -626,7 +622,7 @@ class VectorStore(ABC):
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
"""Async return docs most similar to query.
Args:
@ -644,8 +640,8 @@ class VectorStore(ABC):
return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
@ -659,8 +655,8 @@ class VectorStore(ABC):
raise NotImplementedError
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[Document]:
"""Async return docs most similar to embedding vector.
Args:
@ -686,7 +682,7 @@ class VectorStore(ABC):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -715,7 +711,7 @@ class VectorStore(ABC):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Async return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -750,12 +746,12 @@ class VectorStore(ABC):
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -779,12 +775,12 @@ class VectorStore(ABC):
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
"""Async return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
@ -816,8 +812,8 @@ class VectorStore(ABC):
@classmethod
def from_documents(
cls: Type[VST],
documents: List[Document],
cls: type[VST],
documents: list[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
@ -837,8 +833,8 @@ class VectorStore(ABC):
@classmethod
async def afrom_documents(
cls: Type[VST],
documents: List[Document],
cls: type[VST],
documents: list[Document],
embedding: Embeddings,
**kwargs: Any,
) -> VST:
@ -859,10 +855,10 @@ class VectorStore(ABC):
@classmethod
@abstractmethod
def from_texts(
cls: Type[VST],
texts: List[str],
cls: type[VST],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings.
@ -880,10 +876,10 @@ class VectorStore(ABC):
@classmethod
async def afrom_texts(
cls: Type[VST],
texts: List[str],
cls: type[VST],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> VST:
"""Async return VectorStore initialized from texts and embeddings.
@ -902,7 +898,7 @@ class VectorStore(ABC):
None, cls.from_texts, texts, embedding, metadatas, **kwargs
)
def _get_retriever_tags(self) -> List[str]:
def _get_retriever_tags(self) -> list[str]:
"""Get tags for retriever."""
tags = [self.__class__.__name__]
if self.embeddings:
@ -991,7 +987,7 @@ class VectorStoreRetriever(BaseRetriever):
@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: Dict) -> Any:
def validate_search_type(cls, values: dict) -> Any:
"""Validate search type.
Args:
@ -1040,7 +1036,7 @@ class VectorStoreRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
@ -1060,7 +1056,7 @@ class VectorStoreRetriever(BaseRetriever):
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
@ -1080,7 +1076,7 @@ class VectorStoreRetriever(BaseRetriever):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
"""Add documents to the vectorstore.
Args:
@ -1093,8 +1089,8 @@ class VectorStoreRetriever(BaseRetriever):
return self.vectorstore.add_documents(documents, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
self, documents: list[Document], **kwargs: Any
) -> list[str]:
"""Async add documents to the vectorstore.
Args:
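
The VectorStoreRetriever above is normally obtained via VectorStore.as_retriever (not shown in this hunk); a usage sketch, where `vectorstore` is any concrete implementation:

# search_type and score_threshold map onto the branches validated above.
retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5, "k": 4},
)
docs = retriever.invoke("what is PEP 585?")  # BaseRetriever is a Runnable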

View File

@ -7,12 +7,9 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from langchain_core._api import deprecated
@ -153,7 +150,7 @@ class InMemoryVectorStore(VectorStore):
"""
# TODO: would be nice to change to
# Dict[str, Document] at some point (will be a breaking change)
self.store: Dict[str, Dict[str, Any]] = {}
self.store: dict[str, dict[str, Any]] = {}
self.embedding = embedding
@property
@ -170,10 +167,10 @@ class InMemoryVectorStore(VectorStore):
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
documents: list[Document],
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""Add documents to the store."""
texts = [doc.page_content for doc in documents]
vectors = self.embedding.embed_documents(texts)
@ -204,8 +201,8 @@ class InMemoryVectorStore(VectorStore):
return ids_
async def aadd_documents(
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any
) -> List[str]:
self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any
) -> list[str]:
"""Add documents to the store."""
texts = [doc.page_content for doc in documents]
vectors = await self.embedding.aembed_documents(texts)
@ -219,7 +216,7 @@ class InMemoryVectorStore(VectorStore):
id_iterator: Iterator[Optional[str]] = (
iter(ids) if ids else iter(doc.id for doc in documents)
)
ids_: List[str] = []
ids_: list[str] = []
for doc, vector in zip(documents, vectors):
doc_id = next(id_iterator)
@ -234,7 +231,7 @@ class InMemoryVectorStore(VectorStore):
return ids_
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Get documents by their ids.
Args:
@ -313,7 +310,7 @@ class InMemoryVectorStore(VectorStore):
"failed": [],
}
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Async get documents by their ids.
Args:
@ -326,11 +323,11 @@ class InMemoryVectorStore(VectorStore):
def _similarity_search_with_score_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float, List[float]]]:
) -> list[tuple[Document, float, list[float]]]:
result = []
for doc in self.store.values():
vector = doc["vector"]
@ -351,11 +348,11 @@ class InMemoryVectorStore(VectorStore):
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
return [
(doc, similarity)
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
@ -368,7 +365,7 @@ class InMemoryVectorStore(VectorStore):
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
@ -379,7 +376,7 @@ class InMemoryVectorStore(VectorStore):
async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
) -> list[tuple[Document, float]]:
embedding = await self.embedding.aembed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
@ -390,10 +387,10 @@ class InMemoryVectorStore(VectorStore):
def similarity_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k,
@ -402,18 +399,18 @@ class InMemoryVectorStore(VectorStore):
return [doc for doc, _ in docs_and_scores]
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[Document]:
return self.similarity_search_by_vector(embedding, k, **kwargs)
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
return [
doc
for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
@ -421,12 +418,12 @@ class InMemoryVectorStore(VectorStore):
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
prefetch_hits = self._similarity_search_with_score_by_vector(
embedding=embedding,
k=fetch_k,
@ -456,7 +453,7 @@ class InMemoryVectorStore(VectorStore):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
embedding_vector = self.embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding_vector,
@ -473,7 +470,7 @@ class InMemoryVectorStore(VectorStore):
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
) -> list[Document]:
embedding_vector = await self.embedding.aembed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding_vector,
@ -486,9 +483,9 @@ class InMemoryVectorStore(VectorStore):
@classmethod
def from_texts(
cls,
texts: List[str],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> InMemoryVectorStore:
store = cls(
@ -500,9 +497,9 @@ class InMemoryVectorStore(VectorStore):
@classmethod
async def afrom_texts(
cls,
texts: List[str],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> InMemoryVectorStore:
store = cls(
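
An end-to-end sketch of the in-memory store above. DeterministicFakeEmbedding is assumed from langchain_core's fake embeddings; any Embeddings implementation works:

from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore.from_texts(
    ["cats purr", "dogs bark"],
    embedding=DeterministicFakeEmbedding(size=8),
    metadatas=[{"id": 1}, {"id": 2}],
)
for doc, score in store.similarity_search_with_score("cats", k=1):
    print(doc.page_content, score)  # cosine similarity, highest first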

View File

@ -76,7 +76,7 @@ def maximal_marginal_relevance(
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
) -> list[int]:
"""Calculate maximal marginal relevance.
Args:
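
For reference, the greedy MMR selection picks, at each step, the candidate maximizing the standard trade-off between query similarity and redundancy, where λ is `lambda_mult`, Q the query embedding, and S the already-selected set:

\operatorname{MMR}(D_i) = \lambda \,\operatorname{sim}(D_i, Q) \;-\; (1-\lambda)\,\max_{D_j \in S} \operatorname{sim}(D_i, D_j)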

libs/core/poetry.lock (generated)
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -1186,7 +1186,7 @@ develop = true
[package.dependencies]
httpx = "^0.27.0"
langchain-core = ">=0.3.0.dev1"
langchain-core = "^0.3.0"
pytest = ">=7,<9"
syrupy = "^4"
@ -1196,7 +1196,7 @@ url = "../standard-tests"
[[package]]
name = "langchain-text-splitters"
version = "0.3.0.dev1"
version = "0.3.0"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.9,<4.0"
@ -1204,7 +1204,7 @@ files = []
develop = true
[package.dependencies]
langchain-core = "^0.3.0.dev1"
langchain-core = "^0.3.0"
[package.source]
type = "directory"

View File

@ -41,8 +41,8 @@ python = ">=3.12.4"
[tool.poetry.extras]
[tool.ruff.lint]
select = [ "B", "E", "F", "I", "T201", "UP",]
ignore = [ "UP006", "UP007",]
select = ["B", "E", "F", "I", "T201", "UP"]
ignore = ["UP007"]
[tool.coverage.run]
omit = [ "tests/*",]

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, List
from typing import Any
from uuid import uuid4
import pytest
@ -26,7 +26,7 @@ class FakeAsyncTracer(AsyncBaseTracer):
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Run] = []
self.runs: list[Run] = []
async def _persist_run(self, run: Run) -> None:
self.runs.append(run)

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, List
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
@ -30,7 +30,7 @@ class FakeTracer(BaseTracer):
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Run] = []
self.runs: list[Run] = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""

View File

@ -7,7 +7,7 @@ the relevant methods.
from __future__ import annotations
import uuid
from typing import Any, Dict, Iterable, List, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
@ -18,19 +18,19 @@ class CustomAddTextsVectorstore(VectorStore):
"""A vectorstore that only implements add texts."""
def __init__(self) -> None:
self.store: Dict[str, Document] = {}
self.store: dict[str, Document] = {}
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
# One of the kwargs should be `ids` which is a list of ids
# associated with the texts.
# This is not yet enforced in the type signature for backwards compatibility
# with existing implementations.
ids: Optional[List[str]] = None,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
if not isinstance(texts, list):
texts = list(texts)
ids_iter = iter(ids or [])
@ -46,14 +46,14 @@ class CustomAddTextsVectorstore(VectorStore):
ids_.append(id_)
return ids_
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store]
def from_texts( # type: ignore
cls,
texts: List[str],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> CustomAddTextsVectorstore:
vectorstore = CustomAddTextsVectorstore()
@ -62,7 +62,7 @@ class CustomAddTextsVectorstore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError()