mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
core[patch]: Add ruff rule UP006(use PEP585 annotations) (#26574)
* Added rules `UPD006` now that Pydantic is v2+ --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
2ef4c9466f
commit
3a99467ccb
@ -25,7 +25,7 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, List, Literal, Sequence, Union
|
||||
from typing import Any, Literal, Sequence, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
@ -71,7 +71,7 @@ class AgentAction(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "agent"]."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
@ -145,7 +145,7 @@ class AgentFinish(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "agent"]
|
||||
|
||||
|
@ -23,7 +23,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.runnables import run_in_executor
|
||||
@ -157,7 +157,7 @@ class InMemoryCache(BaseCache):
|
||||
Raises:
|
||||
ValueError: If maxsize is less than or equal to 0.
|
||||
"""
|
||||
self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {}
|
||||
if maxsize is not None and maxsize <= 0:
|
||||
raise ValueError("maxsize must be greater than 0")
|
||||
self._maxsize = maxsize
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
@ -118,7 +118,7 @@ class ChainManagerMixin:
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@ -222,13 +222,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running.
|
||||
@ -249,13 +249,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running.
|
||||
@ -280,13 +280,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when the Retriever starts running.
|
||||
@ -303,13 +303,13 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chain starts running.
|
||||
@ -326,14 +326,14 @@ class CallbackManagerMixin:
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when the tool starts running.
|
||||
@ -393,8 +393,8 @@ class RunManagerMixin:
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Override to define a handler for a custom event.
|
||||
@ -470,13 +470,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
@ -497,13 +497,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running.
|
||||
@ -533,7 +533,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled.
|
||||
@ -554,7 +554,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM ends running.
|
||||
@ -573,7 +573,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors.
|
||||
@ -590,13 +590,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when a chain starts running.
|
||||
@ -613,11 +613,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when a chain ends running.
|
||||
@ -636,7 +636,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors.
|
||||
@ -651,14 +651,14 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when the tool starts running.
|
||||
@ -680,7 +680,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when the tool ends running.
|
||||
@ -699,7 +699,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors.
|
||||
@ -718,7 +718,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on an arbitrary text.
|
||||
@ -754,7 +754,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on agent action.
|
||||
@ -773,7 +773,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the agent end.
|
||||
@ -788,13 +788,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the retriever start.
|
||||
@ -815,7 +815,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on the retriever end.
|
||||
@ -833,7 +833,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever error.
|
||||
@ -852,8 +852,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Override to define a handler for a custom event.
|
||||
@ -880,14 +880,14 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager.
|
||||
|
||||
@ -901,8 +901,8 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
Default is None.
|
||||
metadata (Optional[Dict[str, Any]]): The metadata. Default is None.
|
||||
"""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: List[BaseCallbackHandler] = (
|
||||
self.handlers: list[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: list[BaseCallbackHandler] = (
|
||||
inheritable_handlers or []
|
||||
)
|
||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||
@ -1002,7 +1002,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_handlers.remove(handler)
|
||||
|
||||
def set_handlers(
|
||||
self, handlers: List[BaseCallbackHandler], inherit: bool = True
|
||||
self, handlers: list[BaseCallbackHandler], inherit: bool = True
|
||||
) -> None:
|
||||
"""Set handlers as the only handlers on the callback manager.
|
||||
|
||||
@ -1024,7 +1024,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
"""
|
||||
self.set_handlers([handler], inherit=inherit)
|
||||
|
||||
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
|
||||
def add_tags(self, tags: list[str], inherit: bool = True) -> None:
|
||||
"""Add tags to the callback manager.
|
||||
|
||||
Args:
|
||||
@ -1038,7 +1038,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
if inherit:
|
||||
self.inheritable_tags.extend(tags)
|
||||
|
||||
def remove_tags(self, tags: List[str]) -> None:
|
||||
def remove_tags(self, tags: list[str]) -> None:
|
||||
"""Remove tags from the callback manager.
|
||||
|
||||
Args:
|
||||
@ -1048,7 +1048,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.tags.remove(tag)
|
||||
self.inheritable_tags.remove(tag)
|
||||
|
||||
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
|
||||
def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None:
|
||||
"""Add metadata to the callback manager.
|
||||
|
||||
Args:
|
||||
@ -1059,7 +1059,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
|
||||
def remove_metadata(self, keys: List[str]) -> None:
|
||||
def remove_metadata(self, keys: list[str]) -> None:
|
||||
"""Remove metadata from the callback manager.
|
||||
|
||||
Args:
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, TextIO, cast
|
||||
from typing import Any, Optional, TextIO, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
self.file.close()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain.
|
||||
|
||||
@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler):
|
||||
file=self.file,
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain.
|
||||
|
||||
Args:
|
||||
|
@ -14,9 +14,7 @@ from typing import (
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
@ -64,12 +62,12 @@ def trace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[CallbackManagerForChainGroup, None, None]:
|
||||
"""Get a callback manager for a chain group in a context manager.
|
||||
Useful for grouping different calls together as a single run even if
|
||||
@ -144,12 +142,12 @@ async def atrace_as_chain_group(
|
||||
group_name: str,
|
||||
callback_manager: Optional[AsyncCallbackManager] = None,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
|
||||
"""Get an async callback manager for a chain group in a context manager.
|
||||
Useful for grouping different async calls together as a single run even if
|
||||
@ -240,7 +238,7 @@ def shielded(func: Func) -> Func:
|
||||
|
||||
|
||||
def handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
@ -258,10 +256,10 @@ def handle_event(
|
||||
*args: The arguments to pass to the event handler.
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
"""
|
||||
coros: List[Coroutine[Any, Any, Any]] = []
|
||||
coros: list[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
try:
|
||||
message_strings: Optional[List[str]] = None
|
||||
message_strings: Optional[list[str]] = None
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
@ -318,7 +316,7 @@ def handle_event(
|
||||
_run_coros(coros)
|
||||
|
||||
|
||||
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
|
||||
if hasattr(asyncio, "Runner"):
|
||||
# Python 3.11+
|
||||
# Run the coroutines in a new event loop, taking care to
|
||||
@ -399,7 +397,7 @@ async def _ahandle_event_for_handler(
|
||||
|
||||
|
||||
async def ahandle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
ignore_condition_name: Optional[str],
|
||||
*args: Any,
|
||||
@ -446,13 +444,13 @@ class BaseRunManager(RunManagerMixin):
|
||||
self,
|
||||
*,
|
||||
run_id: UUID,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: List[BaseCallbackHandler],
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: list[BaseCallbackHandler],
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize the run manager.
|
||||
|
||||
@ -481,7 +479,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||
def get_noop_manager(cls: type[BRM]) -> BRM:
|
||||
"""Return a manager that doesn't perform any operations.
|
||||
|
||||
Returns:
|
||||
@ -824,7 +822,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running.
|
||||
|
||||
Args:
|
||||
@ -929,7 +927,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
|
||||
@shielded
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when a chain ends running.
|
||||
|
||||
@ -1248,11 +1246,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
) -> list[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@ -1299,11 +1297,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[CallbackManagerForLLMRun]:
|
||||
) -> list[CallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@ -1354,8 +1352,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
inputs: Union[dict[str, Any], Any],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForChainRun:
|
||||
@ -1398,11 +1396,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
input_str: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> CallbackManagerForToolRun:
|
||||
"""Run when tool starts running.
|
||||
@ -1453,7 +1451,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@ -1541,10 +1539,10 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> CallbackManager:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@ -1583,8 +1581,8 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
parent_run_manager: CallbackManagerForChainRun,
|
||||
@ -1681,7 +1679,7 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
manager.add_handler(handler, inherit=True)
|
||||
return manager
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
Args:
|
||||
@ -1716,11 +1714,11 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
) -> list[AsyncCallbackManagerForLLMRun]:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@ -1779,11 +1777,11 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[AsyncCallbackManagerForLLMRun]:
|
||||
) -> list[AsyncCallbackManagerForLLMRun]:
|
||||
"""Async run when LLM starts running.
|
||||
|
||||
Args:
|
||||
@ -1840,8 +1838,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
inputs: Union[dict[str, Any], Any],
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncCallbackManagerForChainRun:
|
||||
@ -1886,7 +1884,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
input_str: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@ -1975,7 +1973,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
serialized: Optional[dict[str, Any]],
|
||||
query: str,
|
||||
run_id: Optional[UUID] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@ -2027,10 +2025,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Configure the async callback manager.
|
||||
|
||||
@ -2069,8 +2067,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: List[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||
handlers: list[BaseCallbackHandler],
|
||||
inheritable_handlers: Optional[list[BaseCallbackHandler]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
*,
|
||||
parent_run_manager: AsyncCallbackManagerForChainRun,
|
||||
@ -2169,7 +2167,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
return manager
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
self, outputs: Union[dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
@ -2202,14 +2200,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
|
||||
|
||||
|
||||
def _configure(
|
||||
callback_manager_cls: Type[T],
|
||||
callback_manager_cls: type[T],
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
local_callbacks: Callbacks = None,
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_tags: Optional[list[str]] = None,
|
||||
local_tags: Optional[list[str]] = None,
|
||||
inheritable_metadata: Optional[dict[str, Any]] = None,
|
||||
local_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> T:
|
||||
"""Configure the callback manager.
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.utils import print_text
|
||||
@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
self.color = color
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain.
|
||||
|
||||
@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
|
||||
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain.
|
||||
|
||||
Args:
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
|
||||
@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming. Only works with LLMs that support streaming."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
|
||||
@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running.
|
||||
@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when a chain starts running.
|
||||
|
||||
@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when a chain ends running.
|
||||
|
||||
Args:
|
||||
@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
self, serialized: dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when the tool starts running.
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence, Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -87,7 +87,7 @@ class BaseChatMessageHistory(ABC):
|
||||
f.write("[]")
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
messages: list[BaseMessage]
|
||||
"""A property or attribute that returns a list of messages.
|
||||
|
||||
In general, getting the messages may involve IO to the underlying
|
||||
@ -95,7 +95,7 @@ class BaseChatMessageHistory(ABC):
|
||||
latency.
|
||||
"""
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
async def aget_messages(self) -> list[BaseMessage]:
|
||||
"""Async version of getting messages.
|
||||
|
||||
Can over-ride this method to provide an efficient async implementation.
|
||||
@ -204,10 +204,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
Stores messages in a memory list.
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage] = Field(default_factory=list)
|
||||
messages: list[BaseMessage] = Field(default_factory=list)
|
||||
"""A list of messages stored in memory."""
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
async def aget_messages(self) -> list[BaseMessage]:
|
||||
"""Async version of getting messages.
|
||||
|
||||
Can over-ride this method to provide an efficient async implementation.
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
@ -25,17 +25,17 @@ class BaseLoader(ABC): # noqa: B024
|
||||
|
||||
# Sub-classes should not implement this method directly. Instead, they
|
||||
# should implement the lazy load method.
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
async def aload(self) -> list[Document]:
|
||||
"""Load data into Document objects."""
|
||||
return [document async for document in self.alazy_load()]
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Do not override this method. It should be considered to be deprecated!
|
||||
@ -108,7 +108,7 @@ class BaseBlobParser(ABC):
|
||||
Generator of documents
|
||||
"""
|
||||
|
||||
def parse(self, blob: Blob) -> List[Document]:
|
||||
def parse(self, blob: Blob) -> list[Document]:
|
||||
"""Eagerly parse the blob into a document or documents.
|
||||
|
||||
This is a convenience method for interactive development environment.
|
||||
|
@ -4,7 +4,7 @@ import contextlib
|
||||
import mimetypes
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast
|
||||
from typing import Any, Generator, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
@ -138,7 +138,7 @@ class Blob(BaseMedia):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any:
|
||||
def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
@ -285,7 +285,7 @@ class Document(BaseMedia):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "document"]
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
def sorted_values(values: Dict[str, str]) -> List[Any]:
|
||||
def sorted_values(values: dict[str, str]) -> list[Any]:
|
||||
"""Return a list of values in dict sorted by key.
|
||||
|
||||
Args:
|
||||
@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
"""VectorStore that contains information about examples."""
|
||||
k: int = 4
|
||||
"""Number of examples to select."""
|
||||
example_keys: Optional[List[str]] = None
|
||||
example_keys: Optional[list[str]] = None
|
||||
"""Optional keys to filter examples to."""
|
||||
input_keys: Optional[List[str]] = None
|
||||
input_keys: Optional[list[str]] = None
|
||||
"""Optional keys to filter input to. If provided, the search is based on
|
||||
the input variables instead of all variables."""
|
||||
vectorstore_kwargs: Optional[Dict[str, Any]] = None
|
||||
vectorstore_kwargs: Optional[dict[str, Any]] = None
|
||||
"""Extra arguments passed to similarity_search function of the vectorstore."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
|
||||
@staticmethod
|
||||
def _example_to_text(
|
||||
example: Dict[str, str], input_keys: Optional[List[str]]
|
||||
example: dict[str, str], input_keys: Optional[list[str]]
|
||||
) -> str:
|
||||
if input_keys:
|
||||
return " ".join(sorted_values({key: example[key] for key in input_keys}))
|
||||
else:
|
||||
return " ".join(sorted_values(example))
|
||||
|
||||
def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
|
||||
def _documents_to_examples(self, documents: list[Document]) -> list[dict]:
|
||||
# Get the examples from the metadata.
|
||||
# This assumes that examples are stored in metadata.
|
||||
examples = [dict(e.metadata) for e in documents]
|
||||
@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
|
||||
return examples
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> str:
|
||||
def add_example(self, example: dict[str, str]) -> str:
|
||||
"""Add a new example to vectorstore.
|
||||
|
||||
Args:
|
||||
@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
)
|
||||
return ids[0]
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> str:
|
||||
async def aadd_example(self, example: dict[str, str]) -> str:
|
||||
"""Async add new example to vectorstore.
|
||||
|
||||
Args:
|
||||
@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
|
||||
class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
"""Select examples based on semantic similarity."""
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select examples based on semantic similarity.
|
||||
|
||||
Args:
|
||||
@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
)
|
||||
return self._documents_to_examples(example_docs)
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Asynchronously select examples based on semantic similarity.
|
||||
|
||||
Args:
|
||||
@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
async def afrom_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> SemanticSimilarityExampleSelector:
|
||||
@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
fetch_k: int = 20
|
||||
"""Number of examples to fetch to rerank."""
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select examples based on Max Marginal Relevance.
|
||||
|
||||
Args:
|
||||
@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
)
|
||||
return self._documents_to_examples(example_docs)
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Asynchronously select examples based on Max Marginal Relevance.
|
||||
|
||||
Args:
|
||||
@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
fetch_k: int = 20,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> MaxMarginalRelevanceExampleSelector:
|
||||
@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
|
||||
@classmethod
|
||||
async def afrom_examples(
|
||||
cls,
|
||||
examples: List[dict],
|
||||
examples: list[dict],
|
||||
embeddings: Embeddings,
|
||||
vectorstore_cls: Type[VectorStore],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
*,
|
||||
k: int = 4,
|
||||
input_keys: Optional[List[str]] = None,
|
||||
input_keys: Optional[list[str]] = None,
|
||||
fetch_k: int = 20,
|
||||
example_keys: Optional[List[str]] = None,
|
||||
example_keys: Optional[list[str]] = None,
|
||||
vectorstore_kwargs: Optional[dict] = None,
|
||||
**vectorstore_cls_kwargs: Any,
|
||||
) -> MaxMarginalRelevanceExampleSelector:
|
||||
|
@ -8,7 +8,6 @@ from typing import (
|
||||
Collection,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
@ -68,7 +67,7 @@ class Node(Serializable):
|
||||
"""Text contained by the node."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Metadata for the node."""
|
||||
links: List[Link] = Field(default_factory=list)
|
||||
links: list[Link] = Field(default_factory=list)
|
||||
"""Links associated with the node."""
|
||||
|
||||
|
||||
@ -189,7 +188,7 @@ class GraphVectorStore(VectorStore):
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
@ -237,7 +236,7 @@ class GraphVectorStore(VectorStore):
|
||||
*,
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
@ -282,7 +281,7 @@ class GraphVectorStore(VectorStore):
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
@ -332,7 +331,7 @@ class GraphVectorStore(VectorStore):
|
||||
self,
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
@ -535,7 +534,7 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
return list(self.traversal_search(query, k=k, depth=0))
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
@ -545,7 +544,7 @@ class GraphVectorStore(VectorStore):
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
return list(
|
||||
self.mmr_traversal_search(
|
||||
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
|
||||
@ -554,10 +553,10 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
|
||||
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
|
||||
if search_type == "similarity":
|
||||
return self.similarity_search(query, **kwargs)
|
||||
elif search_type == "similarity_score_threshold":
|
||||
@ -580,7 +579,7 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
async def asearch(
|
||||
self, query: str, search_type: str, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
if search_type == "similarity":
|
||||
return await self.asimilarity_search(query, **kwargs)
|
||||
elif search_type == "similarity_score_threshold":
|
||||
@ -679,7 +678,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
|
||||
elif self.search_type == "mmr_traversal":
|
||||
@ -691,7 +690,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return [
|
||||
doc
|
||||
|
@ -11,14 +11,11 @@ from typing import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -71,7 +68,7 @@ class _HashedDocument(Document):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def calculate_hashes(cls, values: Dict[str, Any]) -> Any:
|
||||
def calculate_hashes(cls, values: dict[str, Any]) -> Any:
|
||||
"""Root validator to calculate content and metadata hash."""
|
||||
content = values.get("page_content", "")
|
||||
metadata = values.get("metadata", {})
|
||||
@ -125,7 +122,7 @@ class _HashedDocument(Document):
|
||||
)
|
||||
|
||||
|
||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
@ -135,9 +132,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
batch: List[T] = []
|
||||
batch: list[T] = []
|
||||
async for element in iterable:
|
||||
if len(batch) < size:
|
||||
batch.append(element)
|
||||
@ -171,7 +168,7 @@ def _deduplicate_in_order(
|
||||
hashed_documents: Iterable[_HashedDocument],
|
||||
) -> Iterator[_HashedDocument]:
|
||||
"""Deduplicate a list of hashed documents while preserving order."""
|
||||
seen: Set[str] = set()
|
||||
seen: set[str] = set()
|
||||
|
||||
for hashed_doc in hashed_documents:
|
||||
if hashed_doc.hash_ not in seen:
|
||||
@ -349,7 +346,7 @@ def index(
|
||||
uids = []
|
||||
docs_to_index = []
|
||||
uids_to_refresh = []
|
||||
seen_docs: Set[str] = set()
|
||||
seen_docs: set[str] = set()
|
||||
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
|
||||
if doc_exists:
|
||||
if force_update:
|
||||
@ -589,7 +586,7 @@ async def aindex(
|
||||
uids: list[str] = []
|
||||
docs_to_index: list[Document] = []
|
||||
uids_to_refresh = []
|
||||
seen_docs: Set[str] = set()
|
||||
seen_docs: set[str] = set()
|
||||
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
|
||||
if doc_exists:
|
||||
if force_update:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Sequence, TypedDict
|
||||
from typing import Any, Optional, Sequence, TypedDict
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
@ -144,7 +144,7 @@ class RecordManager(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||
def exists(self, keys: Sequence[str]) -> list[bool]:
|
||||
"""Check if the provided keys exist in the database.
|
||||
|
||||
Args:
|
||||
@ -155,7 +155,7 @@ class RecordManager(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
||||
async def aexists(self, keys: Sequence[str]) -> list[bool]:
|
||||
"""Asynchronously check if the provided keys exist in the database.
|
||||
|
||||
Args:
|
||||
@ -173,7 +173,7 @@ class RecordManager(ABC):
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""List records in the database based on the provided filters.
|
||||
|
||||
Args:
|
||||
@ -194,7 +194,7 @@ class RecordManager(ABC):
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Asynchronously list records in the database based on the provided filters.
|
||||
|
||||
Args:
|
||||
@ -241,7 +241,7 @@ class InMemoryRecordManager(RecordManager):
|
||||
super().__init__(namespace)
|
||||
# Each key points to a dictionary
|
||||
# of {'group_id': group_id, 'updated_at': timestamp}
|
||||
self.records: Dict[str, _Record] = {}
|
||||
self.records: dict[str, _Record] = {}
|
||||
self.namespace = namespace
|
||||
|
||||
def create_schema(self) -> None:
|
||||
@ -325,7 +325,7 @@ class InMemoryRecordManager(RecordManager):
|
||||
"""
|
||||
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
|
||||
|
||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||
def exists(self, keys: Sequence[str]) -> list[bool]:
|
||||
"""Check if the provided keys exist in the database.
|
||||
|
||||
Args:
|
||||
@ -336,7 +336,7 @@ class InMemoryRecordManager(RecordManager):
|
||||
"""
|
||||
return [key in self.records for key in keys]
|
||||
|
||||
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
||||
async def aexists(self, keys: Sequence[str]) -> list[bool]:
|
||||
"""Async check if the provided keys exist in the database.
|
||||
|
||||
Args:
|
||||
@ -354,7 +354,7 @@ class InMemoryRecordManager(RecordManager):
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""List records in the database based on the provided filters.
|
||||
|
||||
Args:
|
||||
@ -390,7 +390,7 @@ class InMemoryRecordManager(RecordManager):
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Async list records in the database based on the provided filters.
|
||||
|
||||
Args:
|
||||
@ -449,9 +449,9 @@ class UpsertResponse(TypedDict):
|
||||
indexed to avoid this issue.
|
||||
"""
|
||||
|
||||
succeeded: List[str]
|
||||
succeeded: list[str]
|
||||
"""The IDs that were successfully indexed."""
|
||||
failed: List[str]
|
||||
failed: list[str]
|
||||
"""The IDs that failed to index."""
|
||||
|
||||
|
||||
@ -562,7 +562,7 @@ class DocumentIndex(BaseRetriever):
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
|
||||
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
|
||||
"""Delete by IDs or other criteria.
|
||||
|
||||
Calling delete without any input parameters should raise a ValueError!
|
||||
@ -579,7 +579,7 @@ class DocumentIndex(BaseRetriever):
|
||||
"""
|
||||
|
||||
async def adelete(
|
||||
self, ids: Optional[List[str]] = None, **kwargs: Any
|
||||
self, ids: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> DeleteResponse:
|
||||
"""Delete by IDs or other criteria. Async variant.
|
||||
|
||||
@ -607,7 +607,7 @@ class DocumentIndex(BaseRetriever):
|
||||
ids: Sequence[str],
|
||||
/,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Get documents by id.
|
||||
|
||||
Fewer documents may be returned than requested if some IDs are not found or
|
||||
@ -633,7 +633,7 @@ class DocumentIndex(BaseRetriever):
|
||||
ids: Sequence[str],
|
||||
/,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Get documents by id.
|
||||
|
||||
Fewer documents may be returned than requested if some IDs are not found or
|
||||
|
@ -6,14 +6,11 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -51,7 +48,7 @@ class LangSmithParams(TypedDict, total=False):
|
||||
"""Temperature for generation."""
|
||||
ls_max_tokens: Optional[int]
|
||||
"""Max tokens for generation."""
|
||||
ls_stop: Optional[List[str]]
|
||||
ls_stop: Optional[list[str]]
|
||||
"""Stop words for generation."""
|
||||
|
||||
|
||||
@ -74,7 +71,7 @@ def get_tokenizer() -> Any:
|
||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
|
||||
def _get_token_ids_default_method(text: str) -> List[int]:
|
||||
def _get_token_ids_default_method(text: str) -> list[int]:
|
||||
"""Encode the text into token IDs."""
|
||||
# get the cached tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
@ -117,11 +114,11 @@ class BaseLanguageModel(
|
||||
"""Whether to print out response text."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Callbacks to add to the run trace."""
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
tags: Optional[list[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
|
||||
custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
"""Optional encoder to use for counting tokens."""
|
||||
@ -167,8 +164,8 @@ class BaseLanguageModel(
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -202,8 +199,8 @@ class BaseLanguageModel(
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -235,8 +232,8 @@ class BaseLanguageModel(
|
||||
"""
|
||||
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Not implemented on this class."""
|
||||
# Implement this on child class if there is a way of steering the model to
|
||||
# generate responses that match a given schema.
|
||||
@ -267,7 +264,7 @@ class BaseLanguageModel(
|
||||
@abstractmethod
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -313,7 +310,7 @@ class BaseLanguageModel(
|
||||
@abstractmethod
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -339,7 +336,7 @@ class BaseLanguageModel(
|
||||
"""Get the identifying parameters."""
|
||||
return self.lc_attributes
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
def get_token_ids(self, text: str) -> list[int]:
|
||||
"""Return the ordered ids of the tokens in a text.
|
||||
|
||||
Args:
|
||||
@ -367,7 +364,7 @@ class BaseLanguageModel(
|
||||
"""
|
||||
return len(self.get_token_ids(text))
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input fits in a model's context window.
|
||||
@ -381,7 +378,7 @@ class BaseLanguageModel(
|
||||
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
|
||||
|
||||
@classmethod
|
||||
def _all_required_field_names(cls) -> Set:
|
||||
def _all_required_field_names(cls) -> set:
|
||||
"""DEPRECATED: Kept for backwards compatibility.
|
||||
|
||||
Use get_pydantic_field_names.
|
||||
|
@ -15,11 +15,9 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -223,7 +221,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used.
|
||||
|
||||
Args:
|
||||
@ -277,7 +275,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = ensure_config(config)
|
||||
@ -300,7 +298,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = ensure_config(config)
|
||||
@ -356,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
|
||||
@ -426,7 +424,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[BaseMessageChunk]:
|
||||
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
|
||||
@ -499,12 +497,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
return {}
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
@ -513,7 +511,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def _get_ls_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
@ -550,7 +548,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
return ls_params
|
||||
|
||||
def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
|
||||
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
|
||||
if self.is_lc_serializable():
|
||||
params = {**kwargs, **{"stop": stop}}
|
||||
param_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
@ -567,12 +565,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[list[BaseMessage]],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
@ -658,12 +656,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: List[List[BaseMessage]],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[list[BaseMessage]],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
**kwargs: Any,
|
||||
@ -777,8 +775,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -787,8 +785,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -799,8 +797,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def _generate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -839,7 +837,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
chunks: List[ChatGenerationChunk] = []
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
for chunk in self._stream(messages, stop=stop, **kwargs):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if run_manager:
|
||||
@ -876,8 +874,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def _agenerate_with_cache(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -916,7 +914,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
chunks: List[ChatGenerationChunk] = []
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
async for chunk in self._astream(messages, stop=stop, **kwargs):
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
if run_manager:
|
||||
@ -954,8 +952,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -963,8 +961,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -980,8 +978,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -989,8 +987,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@ -1017,8 +1015,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
@ -1032,8 +1030,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
async def _call_async(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
@ -1048,7 +1046,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def call_as_llm(
|
||||
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, message: str, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
return self.predict(message, stop=stop, **kwargs)
|
||||
|
||||
@ -1069,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1099,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1115,7 +1113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
@ -1123,18 +1121,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type],
|
||||
schema: Union[Dict, type], # noqa: UP006
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@ -1281,8 +1279,8 @@ class SimpleChatModel(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -1294,8 +1292,8 @@ class SimpleChatModel(BaseChatModel):
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -1303,8 +1301,8 @@ class SimpleChatModel(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
|
@ -20,8 +20,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -76,7 +74,7 @@ def _log_error_once(msg: str) -> None:
|
||||
|
||||
|
||||
def create_base_retry_decorator(
|
||||
error_types: List[Type[BaseException]],
|
||||
error_types: list[type[BaseException]],
|
||||
max_retries: int = 1,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
@ -153,10 +151,10 @@ def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
|
||||
|
||||
|
||||
def get_prompts(
|
||||
params: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
params: dict[str, Any],
|
||||
prompts: list[str],
|
||||
cache: Optional[Union[BaseCache, bool, None]] = None,
|
||||
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
|
||||
) -> tuple[dict[int, list], str, list[int], list[str]]:
|
||||
"""Get prompts that are already cached.
|
||||
|
||||
Args:
|
||||
@ -189,10 +187,10 @@ def get_prompts(
|
||||
|
||||
|
||||
async def aget_prompts(
|
||||
params: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
params: dict[str, Any],
|
||||
prompts: list[str],
|
||||
cache: Optional[Union[BaseCache, bool, None]] = None,
|
||||
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
|
||||
) -> tuple[dict[int, list], str, list[int], list[str]]:
|
||||
"""Get prompts that are already cached. Async version.
|
||||
|
||||
Args:
|
||||
@ -225,11 +223,11 @@ async def aget_prompts(
|
||||
|
||||
def update_cache(
|
||||
cache: Union[BaseCache, bool, None],
|
||||
existing_prompts: Dict[int, List],
|
||||
existing_prompts: dict[int, list],
|
||||
llm_string: str,
|
||||
missing_prompt_idxs: List[int],
|
||||
missing_prompt_idxs: list[int],
|
||||
new_results: LLMResult,
|
||||
prompts: List[str],
|
||||
prompts: list[str],
|
||||
) -> Optional[dict]:
|
||||
"""Update the cache and get the LLM output.
|
||||
|
||||
@ -259,11 +257,11 @@ def update_cache(
|
||||
|
||||
async def aupdate_cache(
|
||||
cache: Union[BaseCache, bool, None],
|
||||
existing_prompts: Dict[int, List],
|
||||
existing_prompts: dict[int, list],
|
||||
llm_string: str,
|
||||
missing_prompt_idxs: List[int],
|
||||
missing_prompt_idxs: list[int],
|
||||
new_results: LLMResult,
|
||||
prompts: List[str],
|
||||
prompts: list[str],
|
||||
) -> Optional[dict]:
|
||||
"""Update the cache and get the LLM output. Async version.
|
||||
|
||||
@ -306,7 +304,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
@ -324,7 +322,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
# --- Runnable methods ---
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[str]:
|
||||
def OutputType(self) -> type[str]:
|
||||
"""Get the input type for this runnable."""
|
||||
return str
|
||||
|
||||
@ -343,7 +341,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
def _get_ls_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
@ -383,7 +381,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
config = ensure_config(config)
|
||||
@ -407,7 +405,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
config = ensure_config(config)
|
||||
@ -425,12 +423,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
@ -472,12 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
if not inputs:
|
||||
return []
|
||||
config = get_config_list(config, len(inputs))
|
||||
@ -521,7 +519,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
if type(self)._stream == BaseLLM._stream:
|
||||
@ -583,7 +581,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[str]:
|
||||
if (
|
||||
@ -649,8 +647,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -658,8 +656,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -676,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
@ -704,7 +702,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
@ -747,9 +745,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
def generate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
@ -757,9 +755,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
async def agenerate_prompt(
|
||||
self,
|
||||
prompts: List[PromptValue],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
|
||||
prompts: list[PromptValue],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
@ -769,9 +767,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
def _generate_helper(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]],
|
||||
run_managers: List[CallbackManagerForLLMRun],
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]],
|
||||
run_managers: list[CallbackManagerForLLMRun],
|
||||
new_arg_supported: bool,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -802,14 +800,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
|
||||
*,
|
||||
tags: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
|
||||
tags: Optional[Union[list[str], list[list[str]]]] = None,
|
||||
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, list[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Pass a sequence of prompts to a model and return generations.
|
||||
@ -987,7 +985,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
@staticmethod
|
||||
def _get_run_ids_list(
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
|
||||
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list
|
||||
) -> list:
|
||||
if run_id is None:
|
||||
return [None] * len(prompts)
|
||||
@ -1002,9 +1000,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
async def _agenerate_helper(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]],
|
||||
run_managers: List[AsyncCallbackManagerForLLMRun],
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]],
|
||||
run_managers: list[AsyncCallbackManagerForLLMRun],
|
||||
new_arg_supported: bool,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -1044,14 +1042,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None,
|
||||
*,
|
||||
tags: Optional[Union[List[str], List[List[str]]]] = None,
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, List[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
|
||||
tags: Optional[Union[list[str], list[list[str]]]] = None,
|
||||
metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
|
||||
run_name: Optional[Union[str, list[str]]] = None,
|
||||
run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Asynchronously pass a sequence of prompts to a model and return generations.
|
||||
@ -1239,11 +1237,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input.
|
||||
@ -1287,11 +1285,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
async def _call_async(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
@ -1318,7 +1316,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def predict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1344,7 +1342,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
async def apredict_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
messages: list[BaseMessage],
|
||||
*,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
@ -1367,7 +1365,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
@ -1443,7 +1441,7 @@ class LLM(BaseLLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -1467,7 +1465,7 @@ class LLM(BaseLLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -1500,8 +1498,8 @@ class LLM(BaseLLM):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
@ -1520,8 +1518,8 @@ class LLM(BaseLLM):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
|
@ -11,7 +11,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -55,11 +55,11 @@ class BaseMemory(Serializable, ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memory_variables(self) -> List[str]:
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""The string keys this memory class will add to chain inputs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain.
|
||||
|
||||
Args:
|
||||
@ -69,7 +69,7 @@ class BaseMemory(Serializable, ABC):
|
||||
A dictionary of key-value pairs.
|
||||
"""
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Async return key-value pairs given the text input to the chain.
|
||||
|
||||
Args:
|
||||
@ -81,7 +81,7 @@ class BaseMemory(Serializable, ABC):
|
||||
return await run_in_executor(None, self.load_memory_variables, inputs)
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory.
|
||||
|
||||
Args:
|
||||
@ -90,7 +90,7 @@ class BaseMemory(Serializable, ABC):
|
||||
"""
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
self, inputs: dict[str, Any], outputs: dict[str, str]
|
||||
) -> None:
|
||||
"""Async save the context of this chain run to memory.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
@ -19,7 +19,7 @@ class BaseMessage(Serializable):
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: Union[str, List[Union[str, Dict]]]
|
||||
content: Union[str, list[Union[str, dict]]]
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
@ -64,7 +64,7 @@ class BaseMessage(Serializable):
|
||||
return id_value
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
|
||||
@ -85,7 +85,7 @@ class BaseMessage(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"].
|
||||
"""
|
||||
@ -119,9 +119,9 @@ class BaseMessage(Serializable):
|
||||
|
||||
|
||||
def merge_content(
|
||||
first_content: Union[str, List[Union[str, Dict]]],
|
||||
*contents: Union[str, List[Union[str, Dict]]],
|
||||
) -> Union[str, List[Union[str, Dict]]]:
|
||||
first_content: Union[str, list[Union[str, dict]]],
|
||||
*contents: Union[str, list[Union[str, dict]]],
|
||||
) -> Union[str, list[Union[str, dict]]]:
|
||||
"""Merge two message contents.
|
||||
|
||||
Args:
|
||||
@ -163,7 +163,7 @@ class BaseMessageChunk(BaseMessage):
|
||||
"""Message chunk, which can be concatenated with other Message chunks."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"].
|
||||
"""
|
||||
@ -242,7 +242,7 @@ def message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.model_dump()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> list[dict]:
|
||||
"""Convert a sequence of Messages to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
|
@ -23,7 +23,6 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
@ -166,7 +165,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
raise ValueError(f"Got unexpected message type: {_type}")
|
||||
|
||||
|
||||
def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
|
||||
def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]:
|
||||
"""Convert a sequence of messages from dicts to Message objects.
|
||||
|
||||
Args:
|
||||
@ -208,7 +207,7 @@ def _create_message_from_message_type(
|
||||
content: str,
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||
id: Optional[str] = None,
|
||||
**additional_kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
@ -230,7 +229,7 @@ def _create_message_from_message_type(
|
||||
ValueError: if the message type is not one of "human", "user", "ai",
|
||||
"assistant", "system", "function", or "tool".
|
||||
"""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if tool_call_id is not None:
|
||||
@ -331,7 +330,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
|
||||
def convert_to_messages(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
"""Convert a sequence of messages to a list of messages.
|
||||
|
||||
Args:
|
||||
@ -352,18 +351,18 @@ def _runnable_support(func: Callable) -> Callable:
|
||||
@overload
|
||||
def wrapped(
|
||||
messages: Literal[None] = None, **kwargs: Any
|
||||
) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ...
|
||||
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
|
||||
|
||||
@overload
|
||||
def wrapped(
|
||||
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
|
||||
) -> List[BaseMessage]: ...
|
||||
) -> list[BaseMessage]: ...
|
||||
|
||||
def wrapped(
|
||||
messages: Optional[Sequence[MessageLikeRepresentation]] = None, **kwargs: Any
|
||||
) -> Union[
|
||||
List[BaseMessage],
|
||||
Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
|
||||
list[BaseMessage],
|
||||
Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]],
|
||||
]:
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
|
||||
@ -382,11 +381,11 @@ def filter_messages(
|
||||
*,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
|
||||
exclude_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None,
|
||||
include_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
|
||||
exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
|
||||
include_ids: Optional[Sequence[str]] = None,
|
||||
exclude_ids: Optional[Sequence[str]] = None,
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
"""Filter messages based on name, type or id.
|
||||
|
||||
Args:
|
||||
@ -438,7 +437,7 @@ def filter_messages(
|
||||
]
|
||||
""" # noqa: E501
|
||||
messages = convert_to_messages(messages)
|
||||
filtered: List[BaseMessage] = []
|
||||
filtered: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
if exclude_names and msg.name in exclude_names:
|
||||
continue
|
||||
@ -469,7 +468,7 @@ def merge_message_runs(
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
*,
|
||||
chunk_separator: str = "\n",
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
"""Merge consecutive Messages of the same type.
|
||||
|
||||
**NOTE**: ToolMessages are not merged, as each has a distinct tool call id that
|
||||
@ -539,7 +538,7 @@ def merge_message_runs(
|
||||
if not messages:
|
||||
return []
|
||||
messages = convert_to_messages(messages)
|
||||
merged: List[BaseMessage] = []
|
||||
merged: list[BaseMessage] = []
|
||||
for msg in messages:
|
||||
curr = msg.model_copy(deep=True)
|
||||
last = merged.pop() if merged else None
|
||||
@ -569,21 +568,21 @@ def trim_messages(
|
||||
*,
|
||||
max_tokens: int,
|
||||
token_counter: Union[
|
||||
Callable[[List[BaseMessage]], int],
|
||||
Callable[[list[BaseMessage]], int],
|
||||
Callable[[BaseMessage], int],
|
||||
BaseLanguageModel,
|
||||
],
|
||||
strategy: Literal["first", "last"] = "last",
|
||||
allow_partial: bool = False,
|
||||
end_on: Optional[
|
||||
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
||||
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
|
||||
] = None,
|
||||
start_on: Optional[
|
||||
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
||||
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
|
||||
] = None,
|
||||
include_system: bool = False,
|
||||
text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None,
|
||||
) -> List[BaseMessage]:
|
||||
text_splitter: Optional[Union[Callable[[str], list[str]], TextSplitter]] = None,
|
||||
) -> list[BaseMessage]:
|
||||
"""Trim messages to be below a token count.
|
||||
|
||||
Args:
|
||||
@ -875,13 +874,13 @@ def _first_max_tokens(
|
||||
messages: Sequence[BaseMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
token_counter: Callable[[List[BaseMessage]], int],
|
||||
text_splitter: Callable[[str], List[str]],
|
||||
token_counter: Callable[[list[BaseMessage]], int],
|
||||
text_splitter: Callable[[str], list[str]],
|
||||
partial_strategy: Optional[Literal["first", "last"]] = None,
|
||||
end_on: Optional[
|
||||
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
||||
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
|
||||
] = None,
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
messages = list(messages)
|
||||
idx = 0
|
||||
for i in range(len(messages)):
|
||||
@ -949,17 +948,17 @@ def _last_max_tokens(
|
||||
messages: Sequence[BaseMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
token_counter: Callable[[List[BaseMessage]], int],
|
||||
text_splitter: Callable[[str], List[str]],
|
||||
token_counter: Callable[[list[BaseMessage]], int],
|
||||
text_splitter: Callable[[str], list[str]],
|
||||
allow_partial: bool = False,
|
||||
include_system: bool = False,
|
||||
start_on: Optional[
|
||||
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
||||
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
|
||||
] = None,
|
||||
end_on: Optional[
|
||||
Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]]
|
||||
Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]]
|
||||
] = None,
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
messages = list(messages)
|
||||
if end_on:
|
||||
while messages and not _is_message_type(messages[-1], end_on):
|
||||
@ -984,7 +983,7 @@ def _last_max_tokens(
|
||||
return reversed_[::-1]
|
||||
|
||||
|
||||
_MSG_CHUNK_MAP: Dict[Type[BaseMessage], Type[BaseMessageChunk]] = {
|
||||
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
|
||||
HumanMessage: HumanMessageChunk,
|
||||
AIMessage: AIMessageChunk,
|
||||
SystemMessage: SystemMessageChunk,
|
||||
@ -1024,14 +1023,14 @@ def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
)
|
||||
|
||||
|
||||
def _default_text_splitter(text: str) -> List[str]:
|
||||
def _default_text_splitter(text: str) -> list[str]:
|
||||
splits = text.split("\n")
|
||||
return [s + "\n" for s in splits[:-1]] + splits[-1:]
|
||||
|
||||
|
||||
def _is_message_type(
|
||||
message: BaseMessage,
|
||||
type_: Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]],
|
||||
type_: Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]],
|
||||
) -> bool:
|
||||
types = [type_] if isinstance(type_, (str, type)) else type_
|
||||
types_str = [t for t in types if isinstance(t, str)]
|
||||
|
@ -4,11 +4,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -30,7 +27,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
@ -44,7 +41,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@ -71,7 +68,7 @@ class BaseGenerationOutputParser(
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[T]:
|
||||
def OutputType(self) -> type[T]:
|
||||
"""Return the output type for the parser."""
|
||||
# even though mypy complains this isn't valid,
|
||||
# it is good enough for pydantic to build the schema from
|
||||
@ -156,7 +153,7 @@ class BaseOutputParser(
|
||||
return Union[str, AnyMessage]
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[T]:
|
||||
def OutputType(self) -> type[T]:
|
||||
"""Return the output type for the parser.
|
||||
|
||||
This property is inferred from the first type argument of the class.
|
||||
@ -218,7 +215,7 @@ class BaseOutputParser(
|
||||
run_type="parser",
|
||||
)
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
@ -247,7 +244,7 @@ class BaseOutputParser(
|
||||
"""
|
||||
|
||||
async def aparse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> T:
|
||||
"""Async parse a list of candidate model Generations into a specific format.
|
||||
|
||||
@ -305,7 +302,7 @@ class BaseOutputParser(
|
||||
" This is required for serialization."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
try:
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic
|
||||
@ -42,14 +42,14 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
describing the difference between the previous and the current object.
|
||||
"""
|
||||
|
||||
pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore
|
||||
pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore
|
||||
"""The Pydantic object to use for validation.
|
||||
If None, no validation is performed."""
|
||||
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]:
|
||||
def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
if issubclass(pydantic_object, pydantic.BaseModel):
|
||||
return pydantic_object.model_json_schema()
|
||||
@ -57,7 +57,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
return pydantic_object.model_json_schema()
|
||||
return pydantic_object.model_json_schema()
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union
|
||||
from typing import AsyncIterator, Iterator, List, TypeVar, Union
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
@ -22,7 +22,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
|
||||
Yields:
|
||||
The elements of the iterator, except the last n elements.
|
||||
"""
|
||||
buffer: Deque[T] = deque()
|
||||
buffer: deque[T] = deque()
|
||||
for item in iter:
|
||||
buffer.append(item)
|
||||
if len(buffer) > n:
|
||||
@ -37,7 +37,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
return "list"
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> List[str]:
|
||||
def parse(self, text: str) -> list[str]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
Args:
|
||||
@ -60,7 +60,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
) -> Iterator[List[str]]:
|
||||
) -> Iterator[list[str]]:
|
||||
buffer = ""
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
@ -92,7 +92,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[List[str]]:
|
||||
) -> AsyncIterator[list[str]]:
|
||||
buffer = ""
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
@ -136,7 +136,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
Returns:
|
||||
@ -152,7 +152,7 @@ class CommaSeparatedListOutputParser(ListOutputParser):
|
||||
"eg: `foo, bar, baz` or `foo,bar,baz`"
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
def parse(self, text: str) -> list[str]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
Args:
|
||||
@ -180,7 +180,7 @@ class NumberedListOutputParser(ListOutputParser):
|
||||
"For example: \n\n1. foo\n\n2. bar\n\n3. baz"
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
def parse(self, text: str) -> list[str]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
Args:
|
||||
@ -217,7 +217,7 @@ class MarkdownListOutputParser(ListOutputParser):
|
||||
"""Return the format instructions for the Markdown list output."""
|
||||
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
def parse(self, text: str) -> list[str]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
Args:
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
@ -69,7 +69,7 @@ class ChatGeneration(Generation):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
@ -86,12 +86,12 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
def __add__(
|
||||
self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
|
||||
self, other: Union[ChatGenerationChunk, list[ChatGenerationChunk]]
|
||||
) -> ChatGenerationChunk:
|
||||
if isinstance(other, ChatGenerationChunk):
|
||||
generation_info = merge_dicts(
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
@ -25,7 +25,7 @@ class Generation(Serializable):
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
generation_info: Optional[dict[str, Any]] = None
|
||||
"""Raw response from the provider.
|
||||
|
||||
May include things like the reason for finishing or token log probabilities.
|
||||
@ -40,7 +40,7 @@ class Generation(Serializable):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
@ -49,7 +49,7 @@ class GenerationChunk(Generation):
|
||||
"""Generation chunk, which can be concatenated with other Generation chunks."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output"]
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -18,8 +18,8 @@ class LLMResult(BaseModel):
|
||||
wants to return.
|
||||
"""
|
||||
|
||||
generations: List[
|
||||
List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
|
||||
generations: list[
|
||||
list[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]]
|
||||
]
|
||||
"""Generated outputs.
|
||||
|
||||
@ -45,13 +45,13 @@ class LLMResult(BaseModel):
|
||||
accessing relevant information from standardized fields present in
|
||||
AIMessage.
|
||||
"""
|
||||
run: Optional[List[RunInfo]] = None
|
||||
run: Optional[list[RunInfo]] = None
|
||||
"""List of metadata info for model call for each input."""
|
||||
|
||||
type: Literal["LLMResult"] = "LLMResult" # type: ignore[assignment]
|
||||
"""Type is used exclusively for serialization purposes."""
|
||||
|
||||
def flatten(self) -> List[LLMResult]:
|
||||
def flatten(self) -> list[LLMResult]:
|
||||
"""Flatten generations into a single list.
|
||||
|
||||
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
||||
|
@ -7,7 +7,7 @@ They can be used to represent text, images, or chat message pieces.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Sequence, cast
|
||||
from typing import Literal, Sequence, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@ -33,7 +33,7 @@ class PromptValue(Serializable, ABC):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
This is used to determine the namespace of the object when serializing.
|
||||
Defaults to ["langchain", "schema", "prompt"].
|
||||
@ -45,7 +45,7 @@ class PromptValue(Serializable, ABC):
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class StringPromptValue(PromptValue):
|
||||
type: Literal["StringPromptValue"] = "StringPromptValue"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
This is used to determine the namespace of the object when serializing.
|
||||
Defaults to ["langchain", "prompts", "base"].
|
||||
@ -68,7 +68,7 @@ class StringPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return self.text
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return [HumanMessage(content=self.text)]
|
||||
|
||||
@ -86,12 +86,12 @@ class ChatPromptValue(PromptValue):
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt as a list of messages."""
|
||||
return list(self.messages)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
This is used to determine the namespace of the object when serializing.
|
||||
Defaults to ["langchain", "prompts", "chat"].
|
||||
@ -121,7 +121,7 @@ class ImagePromptValue(PromptValue):
|
||||
"""Return prompt (image URL) as string."""
|
||||
return self.image_url["url"]
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
def to_messages(self) -> list[BaseMessage]:
|
||||
"""Return prompt (image URL) as messages."""
|
||||
return [HumanMessage(content=[cast(dict, self.image_url)])]
|
||||
|
||||
@ -136,7 +136,7 @@ class ChatPromptValueConcrete(ChatPromptValue):
|
||||
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
This is used to determine the namespace of the object when serializing.
|
||||
Defaults to ["langchain", "prompts", "chat"].
|
||||
|
@ -10,10 +10,8 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -45,14 +43,14 @@ class BasePromptTemplate(
|
||||
):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
input_variables: list[str]
|
||||
"""A list of the names of the variables whose values are required as inputs to the
|
||||
prompt."""
|
||||
optional_variables: List[str] = Field(default=[])
|
||||
optional_variables: list[str] = Field(default=[])
|
||||
"""optional_variables: A list of the names of the variables for placeholder
|
||||
or MessagePlaceholder that are optional. These variables are auto inferred
|
||||
from the prompt and user need not provide them."""
|
||||
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
|
||||
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
|
||||
"""A dictionary of the types of the variables the prompt template expects.
|
||||
If not provided, all variables are assumed to be strings."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
@ -62,9 +60,9 @@ class BasePromptTemplate(
|
||||
|
||||
Partial variables populate the template so that you don't need to
|
||||
pass them in every time you call the prompt."""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None # noqa: UP006
|
||||
"""Metadata to be used for tracing."""
|
||||
tags: Optional[List[str]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
@ -89,7 +87,7 @@ class BasePromptTemplate(
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Returns ["langchain", "schema", "prompt_template"]."""
|
||||
return ["langchain", "schema", "prompt_template"]
|
||||
@ -115,7 +113,7 @@ class BasePromptTemplate(
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the input schema for the prompt.
|
||||
|
||||
Args:
|
||||
@ -136,7 +134,7 @@ class BasePromptTemplate(
|
||||
field_definitions={**required_input_variables, **optional_input_variables},
|
||||
)
|
||||
|
||||
def _validate_input(self, inner_input: Any) -> Dict:
|
||||
def _validate_input(self, inner_input: Any) -> dict:
|
||||
if not isinstance(inner_input, dict):
|
||||
if len(self.input_variables) == 1:
|
||||
var_name = self.input_variables[0]
|
||||
@ -163,18 +161,18 @@ class BasePromptTemplate(
|
||||
raise KeyError(msg)
|
||||
return inner_input
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||
def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
|
||||
_inner_input = self._validate_input(inner_input)
|
||||
return self.format_prompt(**_inner_input)
|
||||
|
||||
async def _aformat_prompt_with_error_handling(
|
||||
self, inner_input: Dict
|
||||
self, inner_input: dict
|
||||
) -> PromptValue:
|
||||
_inner_input = self._validate_input(inner_input)
|
||||
return await self.aformat_prompt(**_inner_input)
|
||||
|
||||
def invoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||
self, input: dict, config: Optional[RunnableConfig] = None
|
||||
) -> PromptValue:
|
||||
"""Invoke the prompt.
|
||||
|
||||
@ -199,7 +197,7 @@ class BasePromptTemplate(
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> PromptValue:
|
||||
"""Async invoke the prompt.
|
||||
|
||||
@ -261,7 +259,7 @@ class BasePromptTemplate(
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
|
||||
@ -307,7 +305,7 @@ class BasePromptTemplate(
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of prompt.
|
||||
|
||||
Args:
|
||||
@ -369,7 +367,7 @@ class BasePromptTemplate(
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
|
||||
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
|
||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
|
@ -6,12 +6,10 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
@ -60,12 +58,12 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
|
||||
Args:
|
||||
@ -75,7 +73,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
List of BaseMessages.
|
||||
"""
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format messages from kwargs.
|
||||
Should return a list of BaseMessages.
|
||||
|
||||
@ -89,7 +87,7 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_variables(self) -> List[str]:
|
||||
def input_variables(self) -> list[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
@ -210,7 +208,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
Defaults to None."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@ -223,7 +221,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
@ -251,7 +249,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
return value
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
def input_variables(self) -> list[str]:
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
@ -292,16 +290,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[MessagePromptTemplateT],
|
||||
cls: type[MessagePromptTemplateT],
|
||||
template: str,
|
||||
template_format: str = "f-string",
|
||||
partial_variables: Optional[Dict[str, Any]] = None,
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a string template.
|
||||
@ -329,9 +327,9 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: Type[MessagePromptTemplateT],
|
||||
cls: type[MessagePromptTemplateT],
|
||||
template_file: Union[str, Path],
|
||||
input_variables: List[str],
|
||||
input_variables: list[str],
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a template file.
|
||||
@ -369,7 +367,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""
|
||||
return self.format(**kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
@ -380,7 +378,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""
|
||||
return [self.format(**kwargs)]
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format messages from kwargs.
|
||||
|
||||
Args:
|
||||
@ -392,7 +390,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
return [await self.aformat(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
def input_variables(self) -> list[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
@ -423,7 +421,7 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Role of the message."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@ -462,37 +460,37 @@ _StringImageMessagePromptTemplateT = TypeVar(
|
||||
|
||||
|
||||
class _TextTemplateParam(TypedDict, total=False):
|
||||
text: Union[str, Dict]
|
||||
text: Union[str, dict]
|
||||
|
||||
|
||||
class _ImageTemplateParam(TypedDict, total=False):
|
||||
image_url: Union[str, Dict]
|
||||
image_url: Union[str, dict]
|
||||
|
||||
|
||||
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
prompt: Union[
|
||||
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
|
||||
StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
|
||||
]
|
||||
"""Prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
_msg_class: Type[BaseMessage]
|
||||
_msg_class: type[BaseMessage]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls: Type[_StringImageMessagePromptTemplateT],
|
||||
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
cls: type[_StringImageMessagePromptTemplateT],
|
||||
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
template_format: str = "f-string",
|
||||
*,
|
||||
partial_variables: Optional[Dict[str, Any]] = None,
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
"""Create a class from a string template.
|
||||
@ -511,7 +509,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
ValueError: If the template is not a string or list of strings.
|
||||
"""
|
||||
if isinstance(template, str):
|
||||
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
|
||||
prompt: Union[StringPromptTemplate, list] = PromptTemplate.from_template(
|
||||
template,
|
||||
template_format=template_format,
|
||||
partial_variables=partial_variables,
|
||||
@ -574,9 +572,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
cls: Type[_StringImageMessagePromptTemplateT],
|
||||
cls: type[_StringImageMessagePromptTemplateT],
|
||||
template_file: Union[str, Path],
|
||||
input_variables: List[str],
|
||||
input_variables: list[str],
|
||||
**kwargs: Any,
|
||||
) -> _StringImageMessagePromptTemplateT:
|
||||
"""Create a class from a template file.
|
||||
@ -593,7 +591,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
template = f.read()
|
||||
return cls.from_template(template, input_variables=input_variables, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
@ -604,7 +602,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
"""
|
||||
return [self.format(**kwargs)]
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format messages from kwargs.
|
||||
|
||||
Args:
|
||||
@ -616,7 +614,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
return [await self.aformat(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
def input_variables(self) -> list[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
@ -642,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
content=text, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
else:
|
||||
content: List = []
|
||||
content: list = []
|
||||
for prompt in self.prompt:
|
||||
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||
if isinstance(prompt, StringPromptTemplate):
|
||||
@ -670,7 +668,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
content=text, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
else:
|
||||
content: List = []
|
||||
content: list = []
|
||||
for prompt in self.prompt:
|
||||
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||
if isinstance(prompt, StringPromptTemplate):
|
||||
@ -703,16 +701,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message sent from the user."""
|
||||
|
||||
_msg_class: Type[BaseMessage] = HumanMessage
|
||||
_msg_class: type[BaseMessage] = HumanMessage
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message sent from the AI."""
|
||||
|
||||
_msg_class: Type[BaseMessage] = AIMessage
|
||||
_msg_class: type[BaseMessage] = AIMessage
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@ -722,10 +720,10 @@ class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||
This is a message that is not sent to the user.
|
||||
"""
|
||||
|
||||
_msg_class: Type[BaseMessage] = SystemMessage
|
||||
_msg_class: type[BaseMessage] = SystemMessage
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@ -734,7 +732,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
def lc_attributes(self) -> dict:
|
||||
"""
|
||||
Return a list of attribute names that should be included in the
|
||||
serialized kwargs. These attributes must be accepted by the
|
||||
@ -791,10 +789,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
return ChatPromptValue(messages=messages)
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format kwargs into a list of messages."""
|
||||
return self.format_messages(**kwargs)
|
||||
|
||||
@ -935,7 +933,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
messages: Annotated[List[MessageLike], SkipValidation()]
|
||||
messages: Annotated[list[MessageLike], SkipValidation()]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
@ -999,9 +997,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
]
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
optional_variables: Set[str] = set()
|
||||
partial_vars: Dict[str, Any] = {}
|
||||
input_vars: set[str] = set()
|
||||
optional_variables: set[str] = set()
|
||||
partial_vars: dict[str, Any] = {}
|
||||
for _message in _messages:
|
||||
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
||||
partial_vars[_message.variable_name] = []
|
||||
@ -1022,7 +1020,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
@ -1071,7 +1069,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
messages = values["messages"]
|
||||
input_vars = set()
|
||||
optional_variables = set()
|
||||
input_types: Dict[str, Any] = values.get("input_types", {})
|
||||
input_types: dict[str, Any] = values.get("input_types", {})
|
||||
for message in messages:
|
||||
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
|
||||
input_vars.update(message.input_variables)
|
||||
@ -1125,7 +1123,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
@classmethod
|
||||
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
|
||||
def from_role_strings(
|
||||
cls, string_messages: List[Tuple[str, str]]
|
||||
cls, string_messages: list[tuple[str, str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a list of (role, template) tuples.
|
||||
|
||||
@ -1145,7 +1143,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
@classmethod
|
||||
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
|
||||
def from_strings(
|
||||
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
|
||||
cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a list of (role class, template) tuples.
|
||||
|
||||
@ -1200,7 +1198,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
"""
|
||||
return cls(messages, template_format=template_format)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format the chat template into a list of finalized messages.
|
||||
|
||||
Args:
|
||||
@ -1224,7 +1222,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format the chat template into a list of finalized messages.
|
||||
|
||||
Args:
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@ -31,7 +31,7 @@ from langchain_core.prompts.string import (
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
examples: Optional[list[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
@ -46,7 +46,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
def check_examples_and_selector(cls, values: dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided.
|
||||
|
||||
Args:
|
||||
@ -73,7 +73,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
|
||||
return values
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
def _get_examples(self, **kwargs: Any) -> list[dict]:
|
||||
"""Get the examples to use for formatting the prompt.
|
||||
|
||||
Args:
|
||||
@ -94,7 +94,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
||||
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
|
||||
"""Async get the examples to use for formatting the prompt.
|
||||
|
||||
Args:
|
||||
@ -363,7 +363,7 @@ class FewShotChatMessagePromptTemplate(
|
||||
chain.invoke({"input": "What's 3+3?"})
|
||||
"""
|
||||
|
||||
input_variables: List[str] = Field(default_factory=list)
|
||||
input_variables: list[str] = Field(default_factory=list)
|
||||
"""A list of the names of the variables the prompt template will use
|
||||
to pass to the example_selector, if provided."""
|
||||
|
||||
@ -380,7 +380,7 @@ class FewShotChatMessagePromptTemplate(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
@ -402,7 +402,7 @@ class FewShotChatMessagePromptTemplate(
|
||||
]
|
||||
return messages
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Async format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@ -54,13 +54,13 @@ class PromptTemplate(StringPromptTemplate):
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
def lc_attributes(self) -> dict[str, Any]:
|
||||
return {
|
||||
"template_format": self.template_format,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "prompt"]
|
||||
|
||||
@ -76,7 +76,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validation(cls, values: Dict) -> Any:
|
||||
def pre_init_validation(cls, values: dict) -> Any:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values.get("template") is None:
|
||||
# Will let pydantic fail with a ValidationError if template
|
||||
@ -183,9 +183,9 @@ class PromptTemplate(StringPromptTemplate):
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
examples: List[str],
|
||||
examples: list[str],
|
||||
suffix: str,
|
||||
input_variables: List[str],
|
||||
input_variables: list[str],
|
||||
example_separator: str = "\n\n",
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
@ -215,7 +215,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
def from_file(
|
||||
cls,
|
||||
template_file: Union[str, Path],
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
encoding: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> PromptTemplate:
|
||||
@ -249,7 +249,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
template: str,
|
||||
*,
|
||||
template_format: str = "f-string",
|
||||
partial_variables: Optional[Dict[str, Any]] = None,
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> PromptTemplate:
|
||||
"""Load a prompt template from a template.
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Type
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
@ -60,7 +60,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
return SandboxedEnvironment().from_string(template).render(**kwargs)
|
||||
|
||||
|
||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
def validate_jinja2(template: str, input_variables: list[str]) -> None:
|
||||
"""
|
||||
Validate that the input variables are valid for the template.
|
||||
Issues a warning if missing or extra variables are found.
|
||||
@ -85,7 +85,7 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
warnings.warn(warning_message.strip(), stacklevel=7)
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError as e:
|
||||
@ -114,7 +114,7 @@ def mustache_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
||||
def mustache_template_vars(
|
||||
template: str,
|
||||
) -> Set[str]:
|
||||
) -> set[str]:
|
||||
"""Get the variables from a mustache template.
|
||||
|
||||
Args:
|
||||
@ -123,7 +123,7 @@ def mustache_template_vars(
|
||||
Returns:
|
||||
The variables from the template.
|
||||
"""
|
||||
vars: Set[str] = set()
|
||||
vars: set[str] = set()
|
||||
section_depth = 0
|
||||
for type, key in mustache.tokenize(template):
|
||||
if type == "end":
|
||||
@ -144,7 +144,7 @@ Defs = Dict[str, "Defs"]
|
||||
|
||||
def mustache_schema(
|
||||
template: str,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the variables from a mustache template.
|
||||
|
||||
Args:
|
||||
@ -154,8 +154,8 @@ def mustache_schema(
|
||||
The variables from the template as a Pydantic model.
|
||||
"""
|
||||
fields = {}
|
||||
prefix: Tuple[str, ...] = ()
|
||||
section_stack: List[Tuple[str, ...]] = []
|
||||
prefix: tuple[str, ...] = ()
|
||||
section_stack: list[tuple[str, ...]] = []
|
||||
for type, key in mustache.tokenize(template):
|
||||
if key == ".":
|
||||
continue
|
||||
@ -178,7 +178,7 @@ def mustache_schema(
|
||||
return _create_model_recursive("PromptInput", defs)
|
||||
|
||||
|
||||
def _create_model_recursive(name: str, defs: Defs) -> Type:
|
||||
def _create_model_recursive(name: str, defs: Defs) -> type:
|
||||
return create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
**{
|
||||
@ -188,20 +188,20 @@ def _create_model_recursive(name: str, defs: Defs) -> Type:
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
|
||||
"f-string": formatter.format,
|
||||
"mustache": mustache_formatter,
|
||||
"jinja2": jinja2_formatter,
|
||||
}
|
||||
|
||||
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
|
||||
DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
|
||||
"f-string": formatter.validate_input_variables,
|
||||
"jinja2": validate_jinja2,
|
||||
}
|
||||
|
||||
|
||||
def check_valid_template(
|
||||
template: str, template_format: str, input_variables: List[str]
|
||||
template: str, template_format: str, input_variables: list[str]
|
||||
) -> None:
|
||||
"""Check that template string is valid.
|
||||
|
||||
@ -230,7 +230,7 @@ def check_valid_template(
|
||||
) from exc
|
||||
|
||||
|
||||
def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
def get_template_variables(template: str, template_format: str) -> list[str]:
|
||||
"""Get the variables from the template.
|
||||
|
||||
Args:
|
||||
@ -262,7 +262,7 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""String prompt that exposes the format method, returning a prompt."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "base"]
|
||||
|
||||
|
@ -24,7 +24,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
@ -132,14 +132,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
tags: Optional[List[str]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
"""Optional list of tags associated with the retriever. Defaults to None.
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a retriever with its
|
||||
use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
"""Optional metadata associated with the retriever. Defaults to None.
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
@ -200,7 +200,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Invoke the retriever to get relevant documents.
|
||||
|
||||
Main entry point for synchronous retriever invocations.
|
||||
@ -263,7 +263,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Asynchronously invoke the retriever to get relevant documents.
|
||||
|
||||
Main entry point for asynchronous retriever invocations.
|
||||
@ -324,7 +324,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
|
||||
Args:
|
||||
@ -336,7 +336,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
|
||||
Args:
|
||||
@ -358,11 +358,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
|
||||
Users should favor using `.invoke` or `.batch` rather than
|
||||
@ -402,11 +402,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
|
||||
Users should favor using `.ainvoke` or `.abatch` rather than
|
||||
|
@ -27,8 +27,6 @@ from typing import (
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -273,7 +271,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return name_
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
def InputType(self) -> type[Input]:
|
||||
"""The type of input this Runnable accepts specified as a type annotation."""
|
||||
# First loop through all parent classes and if any of them is
|
||||
# a pydantic model, we will pick up the generic parameterization
|
||||
@ -298,7 +296,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
"""The type of output this Runnable produces specified as a type annotation."""
|
||||
# First loop through bases -- this will help generic
|
||||
# any pydantic models.
|
||||
@ -319,13 +317,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
"""The type of input this Runnable accepts specified as a pydantic model."""
|
||||
return self.get_input_schema()
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get a pydantic model that can be used to validate input to the Runnable.
|
||||
|
||||
Runnables that leverage the configurable_fields and configurable_alternatives
|
||||
@ -360,7 +358,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def get_input_jsonschema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Get a JSON schema that represents the input to the Runnable.
|
||||
|
||||
Args:
|
||||
@ -387,13 +385,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return self.get_input_schema(config).model_json_schema()
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
"""The type of output this Runnable produces specified as a pydantic model."""
|
||||
return self.get_output_schema()
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get a pydantic model that can be used to validate output to the Runnable.
|
||||
|
||||
Runnables that leverage the configurable_fields and configurable_alternatives
|
||||
@ -428,7 +426,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def get_output_jsonschema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Get a JSON schema that represents the output of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -455,13 +453,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return self.get_output_schema(config).model_json_schema()
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
"""List configurable fields for this Runnable."""
|
||||
return []
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""The type of config this Runnable accepts specified as a pydantic model.
|
||||
|
||||
To mark a field as configurable, see the `configurable_fields`
|
||||
@ -509,7 +507,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def get_config_jsonschema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Get a JSON schema that represents the output of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -544,7 +542,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def get_prompts(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> List[BasePromptTemplate]:
|
||||
) -> list[BasePromptTemplate]:
|
||||
"""Return a list of prompts used by this Runnable."""
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
|
||||
@ -614,7 +612,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
return RunnableSequence(self, *others, name=name)
|
||||
|
||||
def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]:
|
||||
def pick(self, keys: Union[str, list[str]]) -> RunnableSerializable[Any, Any]:
|
||||
"""Pick keys from the dict output of this Runnable.
|
||||
|
||||
Pick single key:
|
||||
@ -670,11 +668,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def assign(
|
||||
self,
|
||||
**kwargs: Union[
|
||||
Runnable[Dict[str, Any], Any],
|
||||
Callable[[Dict[str, Any]], Any],
|
||||
Runnable[dict[str, Any], Any],
|
||||
Callable[[dict[str, Any]], Any],
|
||||
Mapping[
|
||||
str,
|
||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||
Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
|
||||
],
|
||||
],
|
||||
) -> RunnableSerializable[Any, Any]:
|
||||
@ -710,7 +708,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
from langchain_core.runnables.passthrough import RunnableAssign
|
||||
|
||||
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@ -744,12 +742,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
"""Default implementation runs invoke in parallel using a thread pool executor.
|
||||
|
||||
The default implementation of batch works well for IO bound runnables.
|
||||
@ -786,7 +784,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Tuple[int, Output]]: ...
|
||||
) -> Iterator[tuple[int, Output]]: ...
|
||||
|
||||
@overload
|
||||
def batch_as_completed(
|
||||
@ -796,7 +794,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ...
|
||||
) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
|
||||
|
||||
def batch_as_completed(
|
||||
self,
|
||||
@ -805,7 +803,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||
) -> Iterator[tuple[int, Union[Output, Exception]]]:
|
||||
"""Run invoke in parallel on a list of inputs,
|
||||
yielding results as they complete."""
|
||||
|
||||
@ -816,7 +814,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def invoke(
|
||||
i: int, input: Input, config: RunnableConfig
|
||||
) -> Tuple[int, Union[Output, Exception]]:
|
||||
) -> tuple[int, Union[Output, Exception]]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
|
||||
@ -848,12 +846,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
"""Default implementation runs ainvoke in parallel using asyncio.gather.
|
||||
|
||||
The default implementation of batch works well for IO bound runnables.
|
||||
@ -902,7 +900,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: Literal[False] = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Output]]: ...
|
||||
) -> AsyncIterator[tuple[int, Output]]: ...
|
||||
|
||||
@overload
|
||||
def abatch_as_completed(
|
||||
@ -912,7 +910,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: Literal[True],
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ...
|
||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
|
||||
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
@ -921,7 +919,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
|
||||
"""Run ainvoke in parallel on a list of inputs,
|
||||
yielding results as they complete.
|
||||
|
||||
@ -947,7 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def ainvoke(
|
||||
i: int, input: Input, config: RunnableConfig
|
||||
) -> Tuple[int, Union[Output, Exception]]:
|
||||
) -> tuple[int, Union[Output, Exception]]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
out: Union[Output, Exception] = await self.ainvoke(
|
||||
@ -1699,8 +1697,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def with_types(
|
||||
self,
|
||||
*,
|
||||
input_type: Optional[Type[Input]] = None,
|
||||
output_type: Optional[Type[Output]] = None,
|
||||
input_type: Optional[type[Input]] = None,
|
||||
output_type: Optional[type[Output]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind input and output types to a Runnable, returning a new Runnable.
|
||||
@ -1722,7 +1720,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def with_retry(
|
||||
self,
|
||||
*,
|
||||
retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,),
|
||||
retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,),
|
||||
wait_exponential_jitter: bool = True,
|
||||
stop_after_attempt: int = 3,
|
||||
) -> Runnable[Input, Output]:
|
||||
@ -1789,7 +1787,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
max_attempt_number=stop_after_attempt,
|
||||
)
|
||||
|
||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||
def map(self) -> Runnable[list[Input], list[Output]]:
|
||||
"""
|
||||
Return a new Runnable that maps a list of inputs to a list of outputs,
|
||||
by calling invoke() with each input.
|
||||
@ -1815,7 +1813,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
*,
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
||||
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,),
|
||||
exception_key: Optional[str] = None,
|
||||
) -> RunnableWithFallbacksT[Input, Output]:
|
||||
"""Add fallbacks to a Runnable, returning a new Runnable.
|
||||
@ -1893,7 +1891,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
serialized: Optional[Dict[str, Any]] = None,
|
||||
serialized: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
@ -1942,7 +1940,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
serialized: Optional[Dict[str, Any]] = None,
|
||||
serialized: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
@ -1977,23 +1975,23 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def _batch_with_config(
|
||||
self,
|
||||
func: Union[
|
||||
Callable[[List[Input]], List[Union[Exception, Output]]],
|
||||
Callable[[list[Input]], list[Union[Exception, Output]]],
|
||||
Callable[
|
||||
[List[Input], List[CallbackManagerForChainRun]],
|
||||
List[Union[Exception, Output]],
|
||||
[list[Input], list[CallbackManagerForChainRun]],
|
||||
list[Union[Exception, Output]],
|
||||
],
|
||||
Callable[
|
||||
[List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
|
||||
List[Union[Exception, Output]],
|
||||
[list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]],
|
||||
list[Union[Exception, Output]],
|
||||
],
|
||||
],
|
||||
input: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
input: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
if not input:
|
||||
@ -2045,27 +2043,27 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
async def _abatch_with_config(
|
||||
self,
|
||||
func: Union[
|
||||
Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
|
||||
Callable[[list[Input]], Awaitable[list[Union[Exception, Output]]]],
|
||||
Callable[
|
||||
[List[Input], List[AsyncCallbackManagerForChainRun]],
|
||||
Awaitable[List[Union[Exception, Output]]],
|
||||
[list[Input], list[AsyncCallbackManagerForChainRun]],
|
||||
Awaitable[list[Union[Exception, Output]]],
|
||||
],
|
||||
Callable[
|
||||
[
|
||||
List[Input],
|
||||
List[AsyncCallbackManagerForChainRun],
|
||||
List[RunnableConfig],
|
||||
list[Input],
|
||||
list[AsyncCallbackManagerForChainRun],
|
||||
list[RunnableConfig],
|
||||
],
|
||||
Awaitable[List[Union[Exception, Output]]],
|
||||
Awaitable[list[Union[Exception, Output]]],
|
||||
],
|
||||
],
|
||||
input: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
input: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
if not input:
|
||||
@ -2073,7 +2071,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
configs = get_config_list(config, len(input))
|
||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
callback_manager.on_chain_start(
|
||||
None,
|
||||
@ -2106,7 +2104,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
coros: List[Awaitable[None]] = []
|
||||
coros: list[Awaitable[None]] = []
|
||||
for run_manager, out in zip(run_managers, output):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
@ -2333,11 +2331,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
@beta_decorator.beta(message="This API is in beta and may change in the future.")
|
||||
def as_tool(
|
||||
self,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
arg_types: Optional[Dict[str, Type]] = None,
|
||||
arg_types: Optional[dict[str, type]] = None,
|
||||
) -> BaseTool:
|
||||
"""Create a BaseTool from a Runnable.
|
||||
|
||||
@ -2573,8 +2571,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
|
||||
|
||||
def _seq_input_schema(
|
||||
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||
) -> Type[BaseModel]:
|
||||
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||
) -> type[BaseModel]:
|
||||
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
|
||||
|
||||
first = steps[0]
|
||||
@ -2599,8 +2597,8 @@ def _seq_input_schema(
|
||||
|
||||
|
||||
def _seq_output_schema(
|
||||
steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||
) -> Type[BaseModel]:
|
||||
steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig]
|
||||
) -> type[BaseModel]:
|
||||
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick
|
||||
|
||||
last = steps[-1]
|
||||
@ -2730,7 +2728,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# the last type.
|
||||
first: Runnable[Input, Any]
|
||||
"""The first Runnable in the sequence."""
|
||||
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
|
||||
middle: list[Runnable[Any, Any]] = Field(default_factory=list)
|
||||
"""The middle Runnables in the sequence."""
|
||||
last: Runnable[Any, Output]
|
||||
"""The last Runnable in the sequence."""
|
||||
@ -2740,7 +2738,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
*steps: RunnableLike,
|
||||
name: Optional[str] = None,
|
||||
first: Optional[Runnable[Any, Any]] = None,
|
||||
middle: Optional[List[Runnable[Any, Any]]] = None,
|
||||
middle: Optional[list[Runnable[Any, Any]]] = None,
|
||||
last: Optional[Runnable[Any, Any]] = None,
|
||||
) -> None:
|
||||
"""Create a new RunnableSequence.
|
||||
@ -2755,7 +2753,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
Raises:
|
||||
ValueError: If the sequence has less than 2 steps.
|
||||
"""
|
||||
steps_flat: List[Runnable] = []
|
||||
steps_flat: list[Runnable] = []
|
||||
if not steps:
|
||||
if first is not None and last is not None:
|
||||
steps_flat = [first] + (middle or []) + [last]
|
||||
@ -2776,12 +2774,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def steps(self) -> List[Runnable[Any, Any]]:
|
||||
def steps(self) -> list[Runnable[Any, Any]]:
|
||||
"""All the Runnables that make up the sequence in order.
|
||||
|
||||
Returns:
|
||||
@ -2804,18 +2802,18 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
def InputType(self) -> type[Input]:
|
||||
"""The type of the input to the Runnable."""
|
||||
return self.first.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
"""The type of the output of the Runnable."""
|
||||
return self.last.OutputType
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the input schema of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -2828,7 +2826,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the output schema of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -2840,7 +2838,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
return _seq_output_schema(self.steps, config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
"""Get the config specs of the Runnable.
|
||||
|
||||
Returns:
|
||||
@ -2862,8 +2860,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
[tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)],
|
||||
lambda x: x[1],
|
||||
)
|
||||
next_deps: Set[str] = set()
|
||||
deps_by_pos: Dict[int, Set[str]] = {}
|
||||
next_deps: set[str] = set()
|
||||
deps_by_pos: dict[int, set[str]] = {}
|
||||
for pos, specs in specs_by_pos:
|
||||
deps_by_pos[pos] = next_deps
|
||||
next_deps = next_deps | {spec[0].id for spec in specs}
|
||||
@ -3065,12 +3063,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
@ -3111,7 +3109,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# Track which inputs (by index) failed so far
|
||||
# If an input has failed it will be present in this map,
|
||||
# and the value will be the exception that was raised.
|
||||
failed_inputs_map: Dict[int, Exception] = {}
|
||||
failed_inputs_map: dict[int, Exception] = {}
|
||||
for stepidx, step in enumerate(self.steps):
|
||||
# Assemble the original indexes of the remaining inputs
|
||||
# (i.e. the ones that haven't failed yet)
|
||||
@ -3191,12 +3189,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
@ -3221,7 +3219,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
@ -3240,7 +3238,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
# Track which inputs (by index) failed so far
|
||||
# If an input has failed it will be present in this map,
|
||||
# and the value will be the exception that was raised.
|
||||
failed_inputs_map: Dict[int, Exception] = {}
|
||||
failed_inputs_map: dict[int, Exception] = {}
|
||||
for stepidx, step in enumerate(self.steps):
|
||||
# Assemble the original indexes of the remaining inputs
|
||||
# (i.e. the ones that haven't failed yet)
|
||||
@ -3305,7 +3303,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
coros: List[Awaitable[None]] = []
|
||||
coros: list[Awaitable[None]] = []
|
||||
for run_manager, out in zip(run_managers, inputs):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
@ -3533,7 +3531,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -3567,7 +3565,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the input schema of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -3596,7 +3594,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get the output schema of the Runnable.
|
||||
|
||||
Args:
|
||||
@ -3609,7 +3607,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
return create_model_v2(self.get_name("Output"), field_definitions=fields)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
"""Get the config specs of the Runnable.
|
||||
|
||||
Returns:
|
||||
@ -3662,7 +3660,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
@ -3724,7 +3722,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
@ -3829,7 +3827,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
input: Iterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
@ -3839,7 +3837,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
yield from self.transform(iter([input]), config)
|
||||
|
||||
async def _atransform(
|
||||
@ -3898,7 +3896,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
input: AsyncIterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
):
|
||||
@ -3909,7 +3907,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[Input]:
|
||||
yield input
|
||||
|
||||
@ -4064,7 +4062,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
# Override the default implementation.
|
||||
# For a runnable generator, we need to bring to provide the
|
||||
# module of the underlying function when creating the model.
|
||||
@ -4100,7 +4098,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
# Override the default implementation.
|
||||
# For a runnable generator, we need to bring to provide the
|
||||
# module of the underlying function when creating the model.
|
||||
@ -4346,7 +4344,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""The pydantic schema for the input to this Runnable.
|
||||
|
||||
Args:
|
||||
@ -4414,7 +4412,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
# Override the default implementation.
|
||||
# For a runnable lambda, we need to bring to provide the
|
||||
# module of the underlying function when creating the model.
|
||||
@ -4435,7 +4433,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
@property
|
||||
def deps(self) -> List[Runnable]:
|
||||
def deps(self) -> list[Runnable]:
|
||||
"""The dependencies of this Runnable.
|
||||
|
||||
Returns:
|
||||
@ -4449,7 +4447,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
else:
|
||||
objects = []
|
||||
|
||||
deps: List[Runnable] = []
|
||||
deps: list[Runnable] = []
|
||||
for obj in objects:
|
||||
if isinstance(obj, Runnable):
|
||||
deps.append(obj)
|
||||
@ -4458,7 +4456,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return deps
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for dep in self.deps for spec in dep.config_specs
|
||||
)
|
||||
@ -4944,7 +4942,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
return create_model_v2(
|
||||
self.get_name("Input"),
|
||||
root=(
|
||||
@ -4962,12 +4960,12 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[List[Output]]:
|
||||
def OutputType(self) -> type[list[Output]]:
|
||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model_v2(
|
||||
self.get_name("Output"),
|
||||
@ -4983,7 +4981,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
@ -4994,42 +4992,42 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
inputs: list[Input],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
configs = [
|
||||
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||
]
|
||||
return self.bound.batch(inputs, configs, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Output]:
|
||||
self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> list[Output]:
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
inputs: list[Input],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
configs = [
|
||||
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||
]
|
||||
return await self.bound.abatch(inputs, configs, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Output]:
|
||||
self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> list[Output]:
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
async def astream_events(
|
||||
@ -5074,7 +5072,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -5181,7 +5179,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
"""The config to bind to the underlying Runnable."""
|
||||
|
||||
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
|
||||
config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
"""The config factories to bind to the underlying Runnable."""
|
||||
@ -5210,10 +5208,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config_factories: Optional[
|
||||
List[Callable[[RunnableConfig], RunnableConfig]]
|
||||
list[Callable[[RunnableConfig], RunnableConfig]]
|
||||
] = None,
|
||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
custom_input_type: Optional[Union[type[Input], BaseModel]] = None,
|
||||
custom_output_type: Optional[Union[type[Output], BaseModel]] = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a RunnableBinding from a Runnable and kwargs.
|
||||
@ -5255,7 +5253,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
return self.bound.get_name(suffix, name=name)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
def InputType(self) -> type[Input]:
|
||||
return (
|
||||
cast(Type[Input], self.custom_input_type)
|
||||
if self.custom_input_type is not None
|
||||
@ -5263,7 +5261,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
return (
|
||||
cast(Type[Output], self.custom_output_type)
|
||||
if self.custom_output_type is not None
|
||||
@ -5272,20 +5270,20 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
if self.custom_input_type is not None:
|
||||
return super().get_input_schema(config)
|
||||
return self.bound.get_input_schema(merge_configs(self.config, config))
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
if self.custom_output_type is not None:
|
||||
return super().get_output_schema(config)
|
||||
return self.bound.get_output_schema(merge_configs(self.config, config))
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
|
||||
@ -5296,7 +5294,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -5330,12 +5328,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
@ -5352,12 +5350,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
@ -5380,7 +5378,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Tuple[int, Output]]: ...
|
||||
) -> Iterator[tuple[int, Output]]: ...
|
||||
|
||||
@overload
|
||||
def batch_as_completed(
|
||||
@ -5390,7 +5388,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Tuple[int, Union[Output, Exception]]]: ...
|
||||
) -> Iterator[tuple[int, Union[Output, Exception]]]: ...
|
||||
|
||||
def batch_as_completed(
|
||||
self,
|
||||
@ -5399,7 +5397,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||
) -> Iterator[tuple[int, Union[Output, Exception]]]:
|
||||
if isinstance(config, Sequence):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
@ -5431,7 +5429,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: Literal[False] = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Output]]: ...
|
||||
) -> AsyncIterator[tuple[int, Output]]: ...
|
||||
|
||||
@overload
|
||||
def abatch_as_completed(
|
||||
@ -5441,7 +5439,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: Literal[True],
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ...
|
||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ...
|
||||
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
@ -5450,7 +5448,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
|
||||
if isinstance(config, Sequence):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
@ -5590,7 +5588,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -5678,8 +5676,8 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
input_type: Optional[Union[type[Input], BaseModel]] = None,
|
||||
output_type: Optional[Union[type[Output], BaseModel]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
|
@ -12,7 +12,6 @@ from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
@ -56,13 +55,13 @@ class EmptyDict(TypedDict, total=False):
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
"""Configuration for a Runnable."""
|
||||
|
||||
tags: List[str]
|
||||
tags: list[str]
|
||||
"""
|
||||
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
You can use these to filter calls.
|
||||
"""
|
||||
|
||||
metadata: Dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
"""
|
||||
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
Keys should be strings, values should be JSON-serializable.
|
||||
@ -90,7 +89,7 @@ class RunnableConfig(TypedDict, total=False):
|
||||
Maximum number of times a call can recurse. If not provided, defaults to 25.
|
||||
"""
|
||||
|
||||
configurable: Dict[str, Any]
|
||||
configurable: dict[str, Any]
|
||||
"""
|
||||
Runtime values for attributes previously made configurable on this Runnable,
|
||||
or sub-Runnables, through .configurable_fields() or .configurable_alternatives().
|
||||
@ -205,7 +204,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
|
||||
def get_config_list(
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
) -> list[RunnableConfig]:
|
||||
"""Get a list of configs from a single config or a list of configs.
|
||||
|
||||
It is useful for subclasses overriding batch() or abatch().
|
||||
@ -255,7 +254,7 @@ def patch_config(
|
||||
recursion_limit: Optional[int] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
configurable: Optional[Dict[str, Any]] = None,
|
||||
configurable: Optional[dict[str, Any]] = None,
|
||||
) -> RunnableConfig:
|
||||
"""Patch a config with new values.
|
||||
|
||||
|
@ -8,12 +8,10 @@ from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -69,27 +67,27 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
def InputType(self) -> type[Input]:
|
||||
return self.default.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
return self.default.OutputType
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
runnable, config = self.prepare(config)
|
||||
return runnable.get_input_schema(config)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
runnable, config = self.prepare(config)
|
||||
return runnable.get_output_schema(config)
|
||||
|
||||
@ -109,7 +107,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
def prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
) -> tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
"""Prepare the Runnable for invocation.
|
||||
|
||||
Args:
|
||||
@ -127,7 +125,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]: ...
|
||||
) -> tuple[Runnable[Input, Output], RunnableConfig]: ...
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@ -143,12 +141,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self.prepare(c) for c in configs]
|
||||
|
||||
@ -164,7 +162,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
return []
|
||||
|
||||
def invoke(
|
||||
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
@ -185,12 +183,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self.prepare(c) for c in configs]
|
||||
|
||||
@ -206,7 +204,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
return []
|
||||
|
||||
async def ainvoke(
|
||||
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||
prepared: tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
@ -362,15 +360,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
)
|
||||
"""
|
||||
|
||||
fields: Dict[str, AnyConfigurableField]
|
||||
fields: dict[str, AnyConfigurableField]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
"""Get the configuration specs for the RunnableConfigurableFields.
|
||||
|
||||
Returns:
|
||||
@ -412,7 +410,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
) -> tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = ensure_config(config)
|
||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||
configurable_fields = {
|
||||
@ -467,7 +465,7 @@ _enums_for_spec: WeakValueDictionary[
|
||||
Union[
|
||||
ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField
|
||||
],
|
||||
Type[StrEnum],
|
||||
type[StrEnum],
|
||||
] = WeakValueDictionary()
|
||||
|
||||
_enums_for_spec_lock = threading.Lock()
|
||||
@ -532,7 +530,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
which: ConfigurableField
|
||||
"""The ConfigurableField to use to choose between alternatives."""
|
||||
|
||||
alternatives: Dict[
|
||||
alternatives: dict[
|
||||
str,
|
||||
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
|
||||
]
|
||||
@ -547,12 +545,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
with _enums_for_spec_lock:
|
||||
if which_enum := _enums_for_spec.get(self.which):
|
||||
pass
|
||||
@ -612,7 +610,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
) -> tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = ensure_config(config)
|
||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||
# remap configurable keys for the chosen alternative
|
||||
|
@ -8,14 +8,10 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
overload,
|
||||
@ -106,8 +102,8 @@ class Node(NamedTuple):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
data: Union[Type[BaseModel], RunnableType]
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
data: Union[type[BaseModel], RunnableType]
|
||||
metadata: Optional[dict[str, Any]]
|
||||
|
||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||
"""Return a copy of the node with optional new id and name.
|
||||
@ -179,7 +175,7 @@ class MermaidDrawMethod(Enum):
|
||||
API = "api" # Uses Mermaid.INK API to render the graph
|
||||
|
||||
|
||||
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
|
||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
|
||||
Args:
|
||||
@ -202,7 +198,7 @@ def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
|
||||
|
||||
def node_data_json(
|
||||
node: Node, *, with_schemas: bool = False
|
||||
) -> Dict[str, Union[str, Dict[str, Any]]]:
|
||||
) -> dict[str, Union[str, dict[str, Any]]]:
|
||||
"""Convert the data of a node to a JSON-serializable format.
|
||||
|
||||
Args:
|
||||
@ -217,7 +213,7 @@ def node_data_json(
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
|
||||
if isinstance(node.data, RunnableSerializable):
|
||||
json: Dict[str, Any] = {
|
||||
json: dict[str, Any] = {
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": node.data.lc_id(),
|
||||
@ -265,10 +261,10 @@ class Graph:
|
||||
edges: List of edges in the graph. Defaults to an empty list.
|
||||
"""
|
||||
|
||||
nodes: Dict[str, Node] = field(default_factory=dict)
|
||||
edges: List[Edge] = field(default_factory=list)
|
||||
nodes: dict[str, Node] = field(default_factory=dict)
|
||||
edges: list[Edge] = field(default_factory=list)
|
||||
|
||||
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
|
||||
def to_json(self, *, with_schemas: bool = False) -> dict[str, list[dict[str, Any]]]:
|
||||
"""Convert the graph to a JSON-serializable format.
|
||||
|
||||
Args:
|
||||
@ -282,7 +278,7 @@ class Graph:
|
||||
node.id: i if is_uuid(node.id) else node.id
|
||||
for i, node in enumerate(self.nodes.values())
|
||||
}
|
||||
edges: List[Dict[str, Any]] = []
|
||||
edges: list[dict[str, Any]] = []
|
||||
for edge in self.edges:
|
||||
edge_dict = {
|
||||
"source": stable_node_ids[edge.source],
|
||||
@ -315,10 +311,10 @@ class Graph:
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
data: Union[Type[BaseModel], RunnableType],
|
||||
data: Union[type[BaseModel], RunnableType],
|
||||
id: Optional[str] = None,
|
||||
*,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
"""Add a node to the graph and return it.
|
||||
|
||||
@ -386,7 +382,7 @@ class Graph:
|
||||
|
||||
def extend(
|
||||
self, graph: Graph, *, prefix: str = ""
|
||||
) -> Tuple[Optional[Node], Optional[Node]]:
|
||||
) -> tuple[Optional[Node], Optional[Node]]:
|
||||
"""Add all nodes and edges from another graph.
|
||||
Note this doesn't check for duplicates, nor does it connect the graphs.
|
||||
|
||||
@ -622,7 +618,7 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the origin."""
|
||||
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
|
||||
found: List[Node] = []
|
||||
found: list[Node] = []
|
||||
for node in graph.nodes.values():
|
||||
if node.id not in exclude and node.id not in targets:
|
||||
found.append(node)
|
||||
@ -635,7 +631,7 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||
If there is no such node, or there are multiple, return None.
|
||||
When drawing the graph, this node would be the destination."""
|
||||
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
|
||||
found: List[Node] = []
|
||||
found: list[Node] = []
|
||||
for node in graph.nodes.values():
|
||||
if node.id not in exclude and node.id not in sources:
|
||||
found.append(node)
|
||||
|
@ -6,10 +6,8 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -238,7 +236,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
history_factory_config: Sequence[ConfigurableFieldSpec]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -366,7 +364,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
self._history_chain = history_chain
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
"""Get the configuration specs for the RunnableWithMessageHistory."""
|
||||
return get_unique_config_specs(
|
||||
super().config_specs + list(self.history_factory_config)
|
||||
@ -374,10 +372,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
fields: Dict = {}
|
||||
fields: dict = {}
|
||||
if self.input_messages_key and self.history_messages_key:
|
||||
fields[self.input_messages_key] = (
|
||||
Union[str, BaseMessage, Sequence[BaseMessage]],
|
||||
@ -398,13 +396,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
output_type = self._history_chain.OutputType
|
||||
return output_type
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Get a pydantic model that can be used to validate output to the Runnable.
|
||||
|
||||
Runnables that leverage the configurable_fields and configurable_alternatives
|
||||
@ -430,15 +428,15 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
module_name=self.__class__.__module__,
|
||||
)
|
||||
|
||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
|
||||
def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
|
||||
async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool:
|
||||
return True
|
||||
|
||||
def _get_input_messages(
|
||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
# If dictionary, try to pluck the single key representing messages
|
||||
@ -481,7 +479,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
def _get_output_messages(
|
||||
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
) -> List[BaseMessage]:
|
||||
) -> list[BaseMessage]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
# If dictionary, try to pluck the single key representing messages
|
||||
@ -514,7 +512,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
f"Got {output_val}."
|
||||
)
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
messages = hist.messages.copy()
|
||||
|
||||
@ -527,8 +525,8 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
return messages
|
||||
|
||||
async def _aenter_history(
|
||||
self, input: Dict[str, Any], config: RunnableConfig
|
||||
) -> List[BaseMessage]:
|
||||
self, input: dict[str, Any], config: RunnableConfig
|
||||
) -> list[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
messages = (await hist.aget_messages()).copy()
|
||||
|
||||
@ -621,7 +619,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
return config
|
||||
|
||||
|
||||
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
|
||||
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> list[str]:
|
||||
"""Get the parameter names of the callable."""
|
||||
sig = inspect.signature(callable_)
|
||||
return list(sig.parameters.keys())
|
||||
|
@ -13,10 +13,8 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -144,7 +142,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
|
||||
"""
|
||||
|
||||
input_type: Optional[Type[Other]] = None
|
||||
input_type: Optional[type[Other]] = None
|
||||
|
||||
func: Optional[
|
||||
Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
|
||||
@ -180,7 +178,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
]
|
||||
] = None,
|
||||
*,
|
||||
input_type: Optional[Type[Other]] = None,
|
||||
input_type: Optional[type[Other]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@ -194,7 +192,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -210,11 +208,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
def assign(
|
||||
cls,
|
||||
**kwargs: Union[
|
||||
Runnable[Dict[str, Any], Any],
|
||||
Callable[[Dict[str, Any]], Any],
|
||||
Runnable[dict[str, Any], Any],
|
||||
Callable[[dict[str, Any]], Any],
|
||||
Mapping[
|
||||
str,
|
||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||
Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]],
|
||||
],
|
||||
],
|
||||
) -> RunnableAssign:
|
||||
@ -228,7 +226,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
A Runnable that merges the Dict input with the output produced by the
|
||||
mapping argument.
|
||||
"""
|
||||
return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
|
||||
return RunnableAssign(RunnableParallel[dict[str, Any]](kwargs))
|
||||
|
||||
def invoke(
|
||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@ -392,9 +390,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
# returns {'input': 5, 'add_step': {'added': 15}}
|
||||
"""
|
||||
|
||||
mapper: RunnableParallel[Dict[str, Any]]
|
||||
mapper: RunnableParallel[dict[str, Any]]
|
||||
|
||||
def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None:
|
||||
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@classmethod
|
||||
@ -402,7 +400,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -418,7 +416,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not issubclass(map_input_schema, RootModel):
|
||||
# ie. it's a dict
|
||||
@ -428,7 +426,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if not issubclass(map_input_schema, RootModel) and not issubclass(
|
||||
@ -453,7 +451,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return super().get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return self.mapper.config_specs
|
||||
|
||||
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
||||
@ -470,11 +468,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
@ -490,19 +488,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
@ -518,19 +516,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
input: Iterator[dict[str, Any]],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps__.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
@ -572,21 +570,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
input: Iterator[dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps__.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
@ -622,10 +620,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
):
|
||||
@ -633,19 +631,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
return self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[dict[str, Any]]:
|
||||
yield input
|
||||
|
||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||
@ -683,9 +681,9 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
print(output_data) # Output: {'name': 'John', 'age': 30}
|
||||
"""
|
||||
|
||||
keys: Union[str, List[str]]
|
||||
keys: Union[str, list[str]]
|
||||
|
||||
def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None:
|
||||
def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None:
|
||||
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@classmethod
|
||||
@ -693,7 +691,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -707,7 +705,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
)
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def _pick(self, input: Dict[str, Any]) -> Any:
|
||||
def _pick(self, input: dict[str, Any]) -> Any:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
@ -723,36 +721,36 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
input: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return self._pick(input)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
input: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return self._pick(input)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
input: Iterator[dict[str, Any]],
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
for chunk in input:
|
||||
picked = self._pick(chunk)
|
||||
if picked is not None:
|
||||
@ -760,18 +758,18 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
input: Iterator[dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for chunk in input:
|
||||
picked = self._pick(chunk)
|
||||
if picked is not None:
|
||||
@ -779,10 +777,10 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
input: AsyncIterator[dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
):
|
||||
@ -790,19 +788,19 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
return self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
input: dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async def input_aiter() -> AsyncIterator[dict[str, Any]]:
|
||||
yield input
|
||||
|
||||
async for chunk in self.atransform(input_aiter(), config, **kwargs):
|
||||
|
@ -71,7 +71,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
runnables: Mapping[str, Runnable[Any, Output]]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec for step in self.runnables.values() for spec in step.config_specs
|
||||
)
|
||||
@ -94,7 +94,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -125,12 +125,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
@ -160,12 +160,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Literal, Sequence, Union
|
||||
from typing import Any, Literal, Sequence, Union
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
@ -110,7 +110,7 @@ class BaseStreamEvent(TypedDict):
|
||||
Each child Runnable that gets invoked as part of the execution of a parent Runnable
|
||||
is assigned its own unique ID.
|
||||
"""
|
||||
tags: NotRequired[List[str]]
|
||||
tags: NotRequired[list[str]]
|
||||
"""Tags associated with the Runnable that generated this event.
|
||||
|
||||
Tags are always inherited from parent Runnables.
|
||||
@ -118,7 +118,7 @@ class BaseStreamEvent(TypedDict):
|
||||
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
|
||||
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
|
||||
"""
|
||||
metadata: NotRequired[Dict[str, Any]]
|
||||
metadata: NotRequired[dict[str, Any]]
|
||||
"""Metadata associated with the Runnable that generated this event.
|
||||
|
||||
Metadata can either be bound to a Runnable using
|
||||
|
@ -18,13 +18,11 @@ from typing import (
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -126,7 +124,7 @@ def asyncio_accepts_context() -> bool:
|
||||
class IsLocalDict(ast.NodeVisitor):
|
||||
"""Check if a name is a local dict."""
|
||||
|
||||
def __init__(self, name: str, keys: Set[str]) -> None:
|
||||
def __init__(self, name: str, keys: set[str]) -> None:
|
||||
"""Initialize the visitor.
|
||||
|
||||
Args:
|
||||
@ -181,7 +179,7 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
"""Check if the first argument of a function is a dict."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.keys: Set[str] = set()
|
||||
self.keys: set[str] = set()
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function.
|
||||
@ -230,8 +228,8 @@ class NonLocals(ast.NodeVisitor):
|
||||
"""Get nonlocal variables accessed."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.loads: Set[str] = set()
|
||||
self.stores: Set[str] = set()
|
||||
self.loads: set[str] = set()
|
||||
self.stores: set[str] = set()
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
"""Visit a name node.
|
||||
@ -271,7 +269,7 @@ class FunctionNonLocals(ast.NodeVisitor):
|
||||
"""Get the nonlocal variables accessed of a function."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.nonlocals: Set[str] = set()
|
||||
self.nonlocals: set[str] = set()
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Visit a function definition.
|
||||
@ -335,7 +333,7 @@ class GetLambdaSource(ast.NodeVisitor):
|
||||
self.source = ast.unparse(node)
|
||||
|
||||
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]:
|
||||
"""Get the keys of the first argument of a function if it is a dict.
|
||||
|
||||
Args:
|
||||
@ -378,7 +376,7 @@ def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
return name
|
||||
|
||||
|
||||
def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||
def get_function_nonlocals(func: Callable) -> list[Any]:
|
||||
"""Get the nonlocal variables accessed by a function.
|
||||
|
||||
Args:
|
||||
@ -392,7 +390,7 @@ def get_function_nonlocals(func: Callable) -> List[Any]:
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = FunctionNonLocals()
|
||||
visitor.visit(tree)
|
||||
values: List[Any] = []
|
||||
values: list[Any] = []
|
||||
closure = inspect.getclosurevars(func)
|
||||
candidates = {**closure.globals, **closure.nonlocals}
|
||||
for k, v in candidates.items():
|
||||
@ -608,12 +606,12 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
description: Optional[str] = None
|
||||
default: Any = None
|
||||
is_shared: bool = False
|
||||
dependencies: Optional[List[str]] = None
|
||||
dependencies: Optional[list[str]] = None
|
||||
|
||||
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> List[ConfigurableFieldSpec]:
|
||||
) -> list[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs.
|
||||
|
||||
Args:
|
||||
@ -628,7 +626,7 @@ def get_unique_config_specs(
|
||||
grouped = groupby(
|
||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||
)
|
||||
unique: List[ConfigurableFieldSpec] = []
|
||||
unique: list[ConfigurableFieldSpec] = []
|
||||
for id, dupes in grouped:
|
||||
first = next(dupes)
|
||||
others = list(dupes)
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -142,10 +142,10 @@ class Operation(FilterDirective):
|
||||
"""
|
||||
|
||||
operator: Operator
|
||||
arguments: List[FilterDirective]
|
||||
arguments: list[FilterDirective]
|
||||
|
||||
def __init__(
|
||||
self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
|
||||
self, operator: Operator, arguments: list[FilterDirective], **kwargs: Any
|
||||
) -> None:
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
|
@ -13,12 +13,9 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -78,11 +75,11 @@ class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||
def _is_annotated_type(typ: type[Any]) -> bool:
|
||||
return get_origin(typ) is Annotated
|
||||
|
||||
|
||||
def _get_annotation_description(arg_type: Type) -> str | None:
|
||||
def _get_annotation_description(arg_type: type) -> str | None:
|
||||
if _is_annotated_type(arg_type):
|
||||
annotated_args = get_args(arg_type)
|
||||
for annotation in annotated_args[1:]:
|
||||
@ -92,7 +89,7 @@ def _get_annotation_description(arg_type: Type) -> str | None:
|
||||
|
||||
|
||||
def _get_filtered_args(
|
||||
inferred_model: Type[BaseModel],
|
||||
inferred_model: type[BaseModel],
|
||||
func: Callable,
|
||||
*,
|
||||
filter_args: Sequence[str],
|
||||
@ -112,7 +109,7 @@ def _get_filtered_args(
|
||||
|
||||
def _parse_python_function_docstring(
|
||||
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
|
||||
) -> Tuple[str, dict]:
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
@ -141,7 +138,7 @@ def _infer_arg_descriptions(
|
||||
*,
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> Tuple[str, dict]:
|
||||
) -> tuple[str, dict]:
|
||||
"""Infer argument descriptions from a function's docstring."""
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
# This is for python < 3.10
|
||||
@ -218,7 +215,7 @@ def create_schema_from_function(
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
include_injected: bool = True,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
@ -273,7 +270,7 @@ def create_schema_from_function(
|
||||
filter_args_ = filter_args
|
||||
else:
|
||||
# Handle classmethods and instance methods
|
||||
existing_params: List[str] = list(sig.parameters.keys())
|
||||
existing_params: list[str] = list(sig.parameters.keys())
|
||||
if existing_params and existing_params[0] in ("self", "cls") and in_class:
|
||||
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
|
||||
else:
|
||||
@ -395,13 +392,13 @@ class ChildTool(BaseTool):
|
||||
description="Callback manager to add to the run trace.",
|
||||
)
|
||||
)
|
||||
tags: Optional[List[str]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
"""Optional list of tags associated with the tool. Defaults to None.
|
||||
These tags will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
"""Optional metadata associated with the tool. Defaults to None.
|
||||
This metadata will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
@ -451,7 +448,7 @@ class ChildTool(BaseTool):
|
||||
return self.get_input_schema().model_json_schema()["properties"]
|
||||
|
||||
@property
|
||||
def tool_call_schema(self) -> Type[BaseModel]:
|
||||
def tool_call_schema(self) -> type[BaseModel]:
|
||||
full_schema = self.get_input_schema()
|
||||
fields = []
|
||||
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
|
||||
@ -465,7 +462,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""The tool's input schema.
|
||||
|
||||
Args:
|
||||
@ -481,7 +478,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, Dict, ToolCall],
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -490,7 +487,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict, ToolCall],
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -499,7 +496,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
|
||||
def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]:
|
||||
"""Convert tool input to a pydantic model.
|
||||
|
||||
Args:
|
||||
@ -536,7 +533,7 @@ class ChildTool(BaseTool):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: Dict) -> Any:
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
"""Raise deprecation warning if callback_manager is used.
|
||||
|
||||
Args:
|
||||
@ -574,7 +571,7 @@ class ChildTool(BaseTool):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
||||
tool_input = self._parse_input(tool_input)
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
@ -585,14 +582,14 @@ class ChildTool(BaseTool):
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Union[str, Dict[str, Any]],
|
||||
tool_input: Union[str, dict[str, Any]],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
@ -696,14 +693,14 @@ class ChildTool(BaseTool):
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
tool_input: Union[str, dict],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
@ -866,7 +863,7 @@ def _prep_run_args(
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Union[str, Dict], Dict]:
|
||||
) -> tuple[Union[str, dict], dict]:
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(input):
|
||||
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
||||
@ -933,7 +930,7 @@ def _stringify(content: Any) -> str:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
|
||||
def _get_type_hints(func: Callable) -> Optional[dict[str, type]]:
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
try:
|
||||
@ -956,7 +953,7 @@ class InjectedToolArg:
|
||||
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
||||
|
||||
|
||||
def _is_injected_arg_type(type_: Type) -> bool:
|
||||
def _is_injected_arg_type(type_: type) -> bool:
|
||||
return any(
|
||||
isinstance(arg, InjectedToolArg)
|
||||
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
|
||||
@ -966,10 +963,10 @@ def _is_injected_arg_type(type_: Type) -> bool:
|
||||
|
||||
def _get_all_basemodel_annotations(
|
||||
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
|
||||
) -> Dict[str, Type]:
|
||||
) -> dict[str, type]:
|
||||
# cls has no subscript: cls = FooBar
|
||||
if isinstance(cls, type):
|
||||
annotations: Dict[str, Type] = {}
|
||||
annotations: dict[str, type] = {}
|
||||
for name, param in inspect.signature(cls).parameters.items():
|
||||
# Exclude hidden init args added by pydantic Config. For example if
|
||||
# BaseModel(extra="allow") then "extra_data" will part of init sig.
|
||||
@ -979,7 +976,7 @@ def _get_all_basemodel_annotations(
|
||||
) and name not in fields:
|
||||
continue
|
||||
annotations[name] = param.annotation
|
||||
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
|
||||
orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
|
||||
# cls has subscript: cls = FooBar[int]
|
||||
else:
|
||||
annotations = _get_all_basemodel_annotations(
|
||||
@ -1011,7 +1008,7 @@ def _get_all_basemodel_annotations(
|
||||
# parent_origin = Baz,
|
||||
# generic_type_vars = (type vars in Baz)
|
||||
# generic_map = {type var in Baz: str}
|
||||
generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple())
|
||||
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
|
||||
generic_map = {
|
||||
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
|
||||
}
|
||||
@ -1027,10 +1024,10 @@ def _get_all_basemodel_annotations(
|
||||
|
||||
|
||||
def _replace_type_vars(
|
||||
type_: Type,
|
||||
generic_map: Optional[Dict[TypeVar, Type]] = None,
|
||||
type_: type,
|
||||
generic_map: Optional[dict[TypeVar, type]] = None,
|
||||
default_to_bound: bool = True,
|
||||
) -> Type:
|
||||
) -> type:
|
||||
generic_map = generic_map or {}
|
||||
if isinstance(type_, TypeVar):
|
||||
if type_ in generic_map:
|
||||
@ -1043,7 +1040,7 @@ def _replace_type_vars(
|
||||
new_args = tuple(
|
||||
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
|
||||
)
|
||||
return _py_38_safe_origin(origin)[new_args]
|
||||
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
|
||||
else:
|
||||
return type_
|
||||
|
||||
@ -1052,5 +1049,5 @@ class BaseToolkit(BaseModel, ABC):
|
||||
"""Base Toolkit representing a collection of related tools."""
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> List[BaseTool]:
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
@ -8,7 +8,7 @@ from langchain_core.tools.base import BaseTool
|
||||
ToolsRenderer = Callable[[List[BaseTool]], str]
|
||||
|
||||
|
||||
def render_text_description(tools: List[BaseTool]) -> str:
|
||||
def render_text_description(tools: list[BaseTool]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Args:
|
||||
@ -36,7 +36,7 @@ def render_text_description(tools: List[BaseTool]) -> str:
|
||||
return "\n".join(descriptions)
|
||||
|
||||
|
||||
def render_text_description_and_args(tools: List[BaseTool]) -> str:
|
||||
def render_text_description_and_args(tools: list[BaseTool]) -> str:
|
||||
"""Render the tool name, description, and args in plain text.
|
||||
|
||||
Args:
|
||||
|
@ -5,10 +5,7 @@ from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -40,7 +37,7 @@ class Tool(BaseTool):
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict, ToolCall],
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -65,7 +62,7 @@ class Tool(BaseTool):
|
||||
# assume it takes a single string input.
|
||||
return {"tool_input": {"type": "string"}}
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
args, kwargs = super()._to_args_and_kwargs(tool_input)
|
||||
# For backwards compatibility. The tool must be run with a single input
|
||||
@ -131,7 +128,7 @@ class Tool(BaseTool):
|
||||
name: str, # We keep these required to support backwards compatibility
|
||||
description: str,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
coroutine: Optional[
|
||||
Callable[..., Awaitable[Any]]
|
||||
] = None, # This is last for compatibility, but should be after func
|
||||
|
@ -6,11 +6,8 @@ from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -50,7 +47,7 @@ class StructuredTool(BaseTool):
|
||||
# TODO: Is this needed?
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict, ToolCall],
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -112,7 +109,7 @@ class StructuredTool(BaseTool):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
*,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
@ -209,7 +206,7 @@ class StructuredTool(BaseTool):
|
||||
)
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> List[str]:
|
||||
def _filter_schema_args(func: Callable) -> list[str]:
|
||||
filter_args = list(FILTERED_ARGS)
|
||||
if config_param := _get_runnable_config_param(func):
|
||||
filter_args.append(config_param)
|
||||
|
@ -8,8 +8,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
@ -52,13 +50,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -93,13 +91,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -238,13 +236,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
@ -282,10 +280,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""End a trace for a chain run.
|
||||
@ -313,7 +311,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -340,15 +338,15 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a tool run.
|
||||
@ -429,13 +427,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -556,13 +554,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -585,13 +583,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._create_llm_run(
|
||||
@ -642,7 +640,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._complete_llm_run(
|
||||
@ -658,7 +656,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._errored_llm_run(
|
||||
@ -670,13 +668,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
@ -697,10 +695,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
chain_run = self._complete_chain_run(
|
||||
@ -716,7 +714,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -731,15 +729,15 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_run = self._create_tool_run(
|
||||
@ -776,7 +774,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_run = self._errored_tool_run(
|
||||
@ -788,13 +786,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -819,7 +817,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
retrieval_run = self._errored_retrieval_run(
|
||||
@ -839,7 +837,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
retrieval_run = self._complete_retrieval_run(
|
||||
|
@ -6,10 +6,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -53,7 +50,7 @@ def tracing_v2_enabled(
|
||||
project_name: Optional[str] = None,
|
||||
*,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
client: Optional[LangSmithClient] = None,
|
||||
) -> Generator[LangChainTracer, None, None]:
|
||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||
@ -169,11 +166,11 @@ def _get_tracer_project() -> str:
|
||||
)
|
||||
|
||||
|
||||
_configure_hooks: List[
|
||||
Tuple[
|
||||
_configure_hooks: list[
|
||||
tuple[
|
||||
ContextVar[Optional[BaseCallbackHandler]],
|
||||
bool,
|
||||
Optional[Type[BaseCallbackHandler]],
|
||||
Optional[type[BaseCallbackHandler]],
|
||||
Optional[str],
|
||||
]
|
||||
] = []
|
||||
@ -182,7 +179,7 @@ _configure_hooks: List[
|
||||
def register_configure_hook(
|
||||
context_var: ContextVar[Optional[Any]],
|
||||
inheritable: bool,
|
||||
handle_class: Optional[Type[BaseCallbackHandler]] = None,
|
||||
handle_class: Optional[type[BaseCallbackHandler]] = None,
|
||||
env_var: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Register a configure hook.
|
||||
|
@ -11,13 +11,9 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -81,9 +77,9 @@ class _TracerCore(ABC):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._schema_format = _schema_format # For internal use only API will change.
|
||||
self.run_map: Dict[str, Run] = {}
|
||||
self.run_map: dict[str, Run] = {}
|
||||
"""Map of run ID to run. Cleared on run end."""
|
||||
self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
|
||||
self.order_map: dict[UUID, tuple[UUID, str]] = {}
|
||||
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
|
||||
|
||||
@abstractmethod
|
||||
@ -137,7 +133,7 @@ class _TracerCore(ABC):
|
||||
self.run_map[str(run.id)] = run
|
||||
|
||||
def _get_run(
|
||||
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
|
||||
self, run_id: UUID, run_type: Union[str, set[str], None] = None
|
||||
) -> Run:
|
||||
try:
|
||||
run = self.run_map[str(run_id)]
|
||||
@ -145,7 +141,7 @@ class _TracerCore(ABC):
|
||||
raise TracerException(f"No indexed run ID {run_id}.") from exc
|
||||
|
||||
if isinstance(run_type, str):
|
||||
run_types: Union[Set[str], None] = {run_type}
|
||||
run_types: Union[set[str], None] = {run_type}
|
||||
else:
|
||||
run_types = run_type
|
||||
if run_types is not None and run.run_type not in run_types:
|
||||
@ -157,12 +153,12 @@ class _TracerCore(ABC):
|
||||
|
||||
def _create_chat_model_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -200,12 +196,12 @@ class _TracerCore(ABC):
|
||||
|
||||
def _create_llm_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -239,7 +235,7 @@ class _TracerCore(ABC):
|
||||
Append token event to LLM run and return the run.
|
||||
"""
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
event_kwargs: dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
llm_run.events.append(
|
||||
@ -258,7 +254,7 @@ class _TracerCore(ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
llm_run = self._get_run(run_id)
|
||||
retry_d: Dict[str, Any] = {
|
||||
retry_d: dict[str, Any] = {
|
||||
"slept": retry_state.idle_for,
|
||||
"attempt": retry_state.attempt_number,
|
||||
}
|
||||
@ -306,12 +302,12 @@ class _TracerCore(ABC):
|
||||
|
||||
def _create_chain_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
@ -358,9 +354,9 @@ class _TracerCore(ABC):
|
||||
|
||||
def _complete_chain_run(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Update a chain run with outputs and end time."""
|
||||
@ -375,7 +371,7 @@ class _TracerCore(ABC):
|
||||
def _errored_chain_run(
|
||||
self,
|
||||
error: BaseException,
|
||||
inputs: Optional[Dict[str, Any]],
|
||||
inputs: Optional[dict[str, Any]],
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -389,14 +385,14 @@ class _TracerCore(ABC):
|
||||
|
||||
def _create_tool_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a tool run."""
|
||||
@ -428,7 +424,7 @@ class _TracerCore(ABC):
|
||||
|
||||
def _complete_tool_run(
|
||||
self,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -454,12 +450,12 @@ class _TracerCore(ABC):
|
||||
|
||||
def _create_retrieval_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
|
@ -6,7 +6,7 @@ import logging
|
||||
import threading
|
||||
import weakref
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
from typing import Any, List, Optional, Sequence, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
import langsmith
|
||||
@ -96,7 +96,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
|
||||
self.skip_unfinished = skip_unfinished
|
||||
self.project_name = project_name
|
||||
self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
|
||||
self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}
|
||||
self.lock = threading.Lock()
|
||||
global _TRACERS
|
||||
_TRACERS.add(self)
|
||||
@ -152,7 +152,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
def _select_eval_results(
|
||||
self,
|
||||
results: Union[EvaluationResult, EvaluationResults],
|
||||
) -> List[EvaluationResult]:
|
||||
) -> list[EvaluationResult]:
|
||||
if isinstance(results, EvaluationResult):
|
||||
results_ = [results]
|
||||
elif isinstance(results, dict) and "results" in results:
|
||||
@ -169,10 +169,10 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
evaluator_response: Union[EvaluationResult, EvaluationResults],
|
||||
run: Run,
|
||||
source_run_id: Optional[UUID] = None,
|
||||
) -> List[EvaluationResult]:
|
||||
) -> list[EvaluationResult]:
|
||||
results = self._select_eval_results(evaluator_response)
|
||||
for res in results:
|
||||
source_info_: Dict[str, Any] = {}
|
||||
source_info_: dict[str, Any] = {}
|
||||
if res.evaluator_info:
|
||||
source_info_ = {**res.evaluator_info, **source_info_}
|
||||
run_id_ = getattr(res, "target_run_id", None)
|
||||
|
@ -8,7 +8,6 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
@ -66,14 +65,14 @@ class RunInfo(TypedDict):
|
||||
"""
|
||||
|
||||
name: str
|
||||
tags: List[str]
|
||||
metadata: Dict[str, Any]
|
||||
tags: list[str]
|
||||
metadata: dict[str, Any]
|
||||
run_type: str
|
||||
inputs: NotRequired[Any]
|
||||
parent_run_id: Optional[UUID]
|
||||
|
||||
|
||||
def _assign_name(name: Optional[str], serialized: Optional[Dict[str, Any]]) -> str:
|
||||
def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str:
|
||||
"""Assign a name to a run."""
|
||||
if name is not None:
|
||||
return name
|
||||
@ -107,15 +106,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
# Map of run ID to run info.
|
||||
# the entry corresponding to a given run id is cleaned
|
||||
# up when each corresponding run ends.
|
||||
self.run_map: Dict[UUID, RunInfo] = {}
|
||||
self.run_map: dict[UUID, RunInfo] = {}
|
||||
# The callback event that corresponds to the end of a parent run
|
||||
# may be invoked BEFORE the callback event that corresponds to the end
|
||||
# of a child run, which results in clean up of run_map.
|
||||
# So we keep track of the mapping between children and parent run IDs
|
||||
# in a separate container. This container is GCed when the tracer is GCed.
|
||||
self.parent_map: Dict[UUID, Optional[UUID]] = {}
|
||||
self.parent_map: dict[UUID, Optional[UUID]] = {}
|
||||
|
||||
self.is_tapped: Dict[UUID, Any] = {}
|
||||
self.is_tapped: dict[UUID, Any] = {}
|
||||
|
||||
# Filter which events will be sent over the queue.
|
||||
self.root_event_filter = _RootEventFilter(
|
||||
@ -132,7 +131,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
self.send_stream = memory_stream.get_send_stream()
|
||||
self.receive_stream = memory_stream.get_receive_stream()
|
||||
|
||||
def _get_parent_ids(self, run_id: UUID) -> List[str]:
|
||||
def _get_parent_ids(self, run_id: UUID) -> list[str]:
|
||||
"""Get the parent IDs of a run (non-recursively) cast to strings."""
|
||||
parent_ids = []
|
||||
|
||||
@ -269,8 +268,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
self,
|
||||
run_id: UUID,
|
||||
*,
|
||||
tags: Optional[List[str]],
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
tags: Optional[list[str]],
|
||||
metadata: Optional[dict[str, Any]],
|
||||
parent_run_id: Optional[UUID],
|
||||
name_: str,
|
||||
run_type: str,
|
||||
@ -296,13 +295,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -337,13 +336,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -384,8 +383,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Generate a custom astream event."""
|
||||
@ -456,7 +455,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
run_info = self.run_map.pop(run_id)
|
||||
inputs_ = run_info["inputs"]
|
||||
|
||||
generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]]
|
||||
generations: Union[list[list[GenerationChunk]], list[list[ChatGenerationChunk]]]
|
||||
output: Union[dict, BaseMessage] = {}
|
||||
|
||||
if run_info["run_type"] == "chat_model":
|
||||
@ -504,13 +503,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
@ -552,10 +551,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
outputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""End a trace for a chain run."""
|
||||
@ -586,15 +585,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
@ -653,13 +652,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -6,7 +6,7 @@ import logging
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import Client
|
||||
@ -91,7 +91,7 @@ class LangChainTracer(BaseTracer):
|
||||
example_id: Optional[Union[UUID, str]] = None,
|
||||
project_name: Optional[str] = None,
|
||||
client: Optional[Client] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LangChain tracer.
|
||||
@ -114,13 +114,13 @@ class LangChainTracer(BaseTracer):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
@ -194,7 +194,7 @@ class LangChainTracer(BaseTracer):
|
||||
)
|
||||
raise ValueError("Failed to get run URL.")
|
||||
|
||||
def _get_tags(self, run: Run) -> List[str]:
|
||||
def _get_tags(self, run: Run) -> list[str]:
|
||||
"""Get combined tags for a run."""
|
||||
tags = set(run.tags or [])
|
||||
tags.update(self.tags or [])
|
||||
|
@ -7,9 +7,7 @@ from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -42,16 +40,16 @@ class LogEntry(TypedDict):
|
||||
"""Name of the object being run."""
|
||||
type: str
|
||||
"""Type of the object being run, eg. prompt, chain, llm, etc."""
|
||||
tags: List[str]
|
||||
tags: list[str]
|
||||
"""List of tags for the run."""
|
||||
metadata: Dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
"""Key-value pairs of metadata for the run."""
|
||||
start_time: str
|
||||
"""ISO-8601 timestamp of when the run started."""
|
||||
|
||||
streamed_output_str: List[str]
|
||||
streamed_output_str: list[str]
|
||||
"""List of LLM tokens streamed by this run, if applicable."""
|
||||
streamed_output: List[Any]
|
||||
streamed_output: list[Any]
|
||||
"""List of output chunks streamed by this run, if available."""
|
||||
inputs: NotRequired[Optional[Any]]
|
||||
"""Inputs to this run. Not available currently via astream_log."""
|
||||
@ -69,7 +67,7 @@ class RunState(TypedDict):
|
||||
|
||||
id: str
|
||||
"""ID of the run."""
|
||||
streamed_output: List[Any]
|
||||
streamed_output: list[Any]
|
||||
"""List of output chunks streamed by Runnable.stream()"""
|
||||
final_output: Optional[Any]
|
||||
"""Final output of the run, usually the result of aggregating (`+`) streamed_output.
|
||||
@ -83,7 +81,7 @@ class RunState(TypedDict):
|
||||
# Do we want tags/metadata on the root run? Client kinda knows it in most situations
|
||||
# tags: List[str]
|
||||
|
||||
logs: Dict[str, LogEntry]
|
||||
logs: dict[str, LogEntry]
|
||||
"""Map of run names to sub-runs. If filters were supplied, this list will
|
||||
contain only the runs that matched the filters."""
|
||||
|
||||
@ -91,14 +89,14 @@ class RunState(TypedDict):
|
||||
class RunLogPatch:
|
||||
"""Patch to the run log."""
|
||||
|
||||
ops: List[Dict[str, Any]]
|
||||
ops: list[dict[str, Any]]
|
||||
"""List of jsonpatch operations, which describe how to create the run state
|
||||
from an empty dict. This is the minimal representation of the log, designed to
|
||||
be serialized as JSON and sent over the wire to reconstruct the log on the other
|
||||
side. Reconstruction of the state can be done with any jsonpatch-compliant library,
|
||||
see https://jsonpatch.com for more information."""
|
||||
|
||||
def __init__(self, *ops: Dict[str, Any]) -> None:
|
||||
def __init__(self, *ops: dict[str, Any]) -> None:
|
||||
self.ops = list(ops)
|
||||
|
||||
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
|
||||
@ -127,7 +125,7 @@ class RunLog(RunLogPatch):
|
||||
state: RunState
|
||||
"""Current state of the log, obtained from applying all ops in sequence."""
|
||||
|
||||
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
|
||||
def __init__(self, *ops: dict[str, Any], state: RunState) -> None:
|
||||
super().__init__(*ops)
|
||||
self.state = state
|
||||
|
||||
@ -219,14 +217,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
self.lock = threading.Lock()
|
||||
self.send_stream = memory_stream.get_send_stream()
|
||||
self.receive_stream = memory_stream.get_receive_stream()
|
||||
self._key_map_by_run_id: Dict[UUID, str] = {}
|
||||
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
|
||||
self._key_map_by_run_id: dict[UUID, str] = {}
|
||||
self._counter_map_by_name: dict[str, int] = defaultdict(int)
|
||||
self.root_id: Optional[UUID] = None
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
||||
return self.receive_stream.__aiter__()
|
||||
|
||||
def send(self, *ops: Dict[str, Any]) -> bool:
|
||||
def send(self, *ops: dict[str, Any]) -> bool:
|
||||
"""Send a patch to the stream, return False if the stream is closed.
|
||||
|
||||
Args:
|
||||
@ -477,7 +475,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
|
||||
def _get_standardized_inputs(
|
||||
run: Run, schema_format: Literal["original", "streaming_events"]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Extract standardized inputs from a run.
|
||||
|
||||
Standardizes the inputs based on the type of the runnable used.
|
||||
@ -631,7 +629,7 @@ async def _astream_log_implementation(
|
||||
except TypeError:
|
||||
prev_final_output = None
|
||||
final_output = chunk
|
||||
patches: List[Dict[str, Any]] = []
|
||||
patches: list[dict[str, Any]] = []
|
||||
if with_streamed_output_list:
|
||||
patches.append(
|
||||
{
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith.schemas import RunBase as BaseRunV2
|
||||
@ -18,7 +18,7 @@ from langchain_core._api import deprecated
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Use string instead.", removal="1.0")
|
||||
def RunTypeEnum() -> Type[RunTypeEnumDep]:
|
||||
def RunTypeEnum() -> type[RunTypeEnumDep]:
|
||||
"""RunTypeEnum."""
|
||||
warnings.warn(
|
||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||
@ -35,7 +35,7 @@ class TracerSessionV1Base(BaseModelV1):
|
||||
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
name: Optional[str] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
@deprecated("0.1.0", removal="1.0")
|
||||
@ -72,10 +72,10 @@ class BaseRun(BaseModelV1):
|
||||
parent_uuid: Optional[str] = None
|
||||
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
execution_order: int
|
||||
child_execution_order: int
|
||||
serialized: Dict[str, Any]
|
||||
serialized: dict[str, Any]
|
||||
session_id: int
|
||||
error: Optional[str] = None
|
||||
|
||||
@ -84,7 +84,7 @@ class BaseRun(BaseModelV1):
|
||||
class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
prompts: list[str]
|
||||
# Temporarily, remove but we will completely remove LLMRun
|
||||
# response: Optional[LLMResult] = None
|
||||
|
||||
@ -93,11 +93,11 @@ class LLMRun(BaseRun):
|
||||
class ChainRun(BaseRun):
|
||||
"""Class for ChainRun."""
|
||||
|
||||
inputs: Dict[str, Any]
|
||||
outputs: Optional[Dict[str, Any]] = None
|
||||
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list)
|
||||
inputs: dict[str, Any]
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
|
||||
child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
|
||||
child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
|
||||
|
||||
|
||||
@deprecated("0.1.0", alternative="Run", removal="1.0")
|
||||
@ -107,9 +107,9 @@ class ToolRun(BaseRun):
|
||||
tool_input: str
|
||||
output: Optional[str] = None
|
||||
action: str
|
||||
child_llm_runs: List[LLMRun] = FieldV1(default_factory=list)
|
||||
child_chain_runs: List[ChainRun] = FieldV1(default_factory=list)
|
||||
child_tool_runs: List[ToolRun] = FieldV1(default_factory=list)
|
||||
child_llm_runs: list[LLMRun] = FieldV1(default_factory=list)
|
||||
child_chain_runs: list[ChainRun] = FieldV1(default_factory=list)
|
||||
child_tool_runs: list[ToolRun] = FieldV1(default_factory=list)
|
||||
|
||||
|
||||
# Begin V2 API Schemas
|
||||
@ -126,9 +126,9 @@ class Run(BaseRunV2):
|
||||
dotted_order: The dotted order.
|
||||
"""
|
||||
|
||||
child_runs: List[Run] = FieldV1(default_factory=list)
|
||||
tags: Optional[List[str]] = FieldV1(default_factory=list)
|
||||
events: List[Dict[str, Any]] = FieldV1(default_factory=list)
|
||||
child_runs: list[Run] = FieldV1(default_factory=list)
|
||||
tags: Optional[list[str]] = FieldV1(default_factory=list)
|
||||
events: list[dict[str, Any]] = FieldV1(default_factory=list)
|
||||
trace_id: Optional[UUID] = None
|
||||
dotted_order: Optional[str] = None
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Merge many dicts, handling specific scenarios where a key exists in both
|
||||
dictionaries but has a value of None in 'left'. In such cases, the method uses the
|
||||
value from 'right' for that key in the merged dictionary.
|
||||
@ -69,7 +69,7 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
|
||||
return merged
|
||||
|
||||
|
||||
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
|
||||
def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]:
|
||||
"""Add many lists, handling None.
|
||||
|
||||
Args:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
|
||||
def env_var_is_set(env_var: str) -> bool:
|
||||
@ -22,8 +22,8 @@ def env_var_is_set(env_var: str) -> bool:
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
data: Dict[str, Any],
|
||||
key: Union[str, List[str]],
|
||||
data: dict[str, Any],
|
||||
key: Union[str, list[str]],
|
||||
env_key: str,
|
||||
default: Optional[str] = None,
|
||||
) -> str:
|
||||
|
@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_function(
|
||||
model: Type,
|
||||
model: type,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
@ -106,8 +106,10 @@ def convert_pydantic_to_openai_function(
|
||||
"""
|
||||
if hasattr(model, "model_json_schema"):
|
||||
schema = model.model_json_schema() # Pydantic 2
|
||||
else:
|
||||
elif hasattr(model, "schema"):
|
||||
schema = model.schema() # Pydantic 1
|
||||
else:
|
||||
raise TypeError("Model must be a Pydantic model.")
|
||||
schema = dereference_refs(schema)
|
||||
if "definitions" in schema: # pydantic 1
|
||||
schema.pop("definitions", None)
|
||||
@ -128,7 +130,7 @@ def convert_pydantic_to_openai_function(
|
||||
removal="1.0",
|
||||
)
|
||||
def convert_pydantic_to_openai_tool(
|
||||
model: Type[BaseModel],
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
@ -194,8 +196,8 @@ def convert_python_function_to_openai_function(
|
||||
)
|
||||
|
||||
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription:
|
||||
visited: Dict = {}
|
||||
def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
|
||||
visited: dict = {}
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
model = cast(
|
||||
@ -209,11 +211,11 @@ _MAX_TYPED_DICT_RECURSION = 25
|
||||
|
||||
|
||||
def _convert_any_typed_dicts_to_pydantic(
|
||||
type_: Type,
|
||||
type_: type,
|
||||
*,
|
||||
visited: Dict,
|
||||
visited: dict,
|
||||
depth: int = 0,
|
||||
) -> Type:
|
||||
) -> type:
|
||||
from pydantic.v1 import Field as Field_v1
|
||||
from pydantic.v1 import create_model as create_model_v1
|
||||
|
||||
@ -267,9 +269,9 @@ def _convert_any_typed_dicts_to_pydantic(
|
||||
subscriptable_origin = _py_38_safe_origin(origin)
|
||||
type_args = tuple(
|
||||
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
|
||||
for arg in type_args
|
||||
for arg in type_args # type: ignore[index]
|
||||
)
|
||||
return subscriptable_origin[type_args]
|
||||
return subscriptable_origin[type_args] # type: ignore[index]
|
||||
else:
|
||||
return type_
|
||||
|
||||
@ -333,10 +335,10 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: Union[Dict[str, Any], Type, Callable, BaseTool],
|
||||
function: Union[dict[str, Any], type, Callable, BaseTool],
|
||||
*,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI function.
|
||||
|
||||
.. versionchanged:: 0.2.29
|
||||
@ -411,10 +413,10 @@ def convert_to_openai_function(
|
||||
|
||||
|
||||
def convert_to_openai_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
|
||||
*,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI tool.
|
||||
|
||||
.. versionchanged:: 0.2.29
|
||||
@ -441,13 +443,13 @@ def convert_to_openai_tool(
|
||||
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
|
||||
return tool
|
||||
oai_function = convert_to_openai_function(tool, strict=strict)
|
||||
oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function}
|
||||
oai_tool: dict[str, Any] = {"type": "function", "function": oai_function}
|
||||
return oai_tool
|
||||
|
||||
|
||||
def tool_example_to_messages(
|
||||
input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None
|
||||
) -> List[BaseMessage]:
|
||||
input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None
|
||||
) -> list[BaseMessage]:
|
||||
"""Convert an example into a list of messages that can be fed into an LLM.
|
||||
|
||||
This code is an adapter that converts a single example to a list of messages
|
||||
@ -511,7 +513,7 @@ def tool_example_to_messages(
|
||||
tool_example_to_messages(txt, [tool_call])
|
||||
)
|
||||
"""
|
||||
messages: List[BaseMessage] = [HumanMessage(content=input)]
|
||||
messages: list[BaseMessage] = [HumanMessage(content=input)]
|
||||
openai_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
openai_tool_calls.append(
|
||||
@ -540,10 +542,10 @@ def tool_example_to_messages(
|
||||
|
||||
def _parse_google_docstring(
|
||||
docstring: Optional[str],
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> Tuple[str, dict]:
|
||||
) -> tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
@ -590,12 +592,12 @@ def _parse_google_docstring(
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _py_38_safe_origin(origin: Type) -> Type:
|
||||
origin_union_type_map: Dict[Type, Any] = (
|
||||
def _py_38_safe_origin(origin: type) -> type:
|
||||
origin_union_type_map: dict[type, Any] = (
|
||||
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||
)
|
||||
|
||||
origin_map: Dict[Type, Any] = {
|
||||
origin_map: dict[type, Any] = {
|
||||
dict: Dict,
|
||||
list: List,
|
||||
tuple: Tuple,
|
||||
@ -610,8 +612,8 @@ def _py_38_safe_origin(origin: Type) -> Type:
|
||||
|
||||
|
||||
def _recursive_set_additional_properties_false(
|
||||
schema: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
schema: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(schema, dict):
|
||||
# Check if 'required' is a key at the current level or if the schema is empty,
|
||||
# in which case additionalProperties still needs to be specified.
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
@ -163,7 +163,7 @@ def _parse_json(
|
||||
return parser(json_str)
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string and check that it
|
||||
contains the expected keys.
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
|
||||
def _retrieve_ref(path: str, schema: dict) -> dict:
|
||||
@ -24,9 +24,9 @@ def _retrieve_ref(path: str, schema: dict) -> dict:
|
||||
|
||||
def _dereference_refs_helper(
|
||||
obj: Any,
|
||||
full_schema: Dict[str, Any],
|
||||
full_schema: dict[str, Any],
|
||||
skip_keys: Sequence[str],
|
||||
processed_refs: Optional[Set[str]] = None,
|
||||
processed_refs: Optional[set[str]] = None,
|
||||
) -> Any:
|
||||
if processed_refs is None:
|
||||
processed_refs = set()
|
||||
@ -63,8 +63,8 @@ def _dereference_refs_helper(
|
||||
|
||||
|
||||
def _infer_skip_keys(
|
||||
obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None
|
||||
) -> List[str]:
|
||||
obj: Any, full_schema: dict, processed_refs: Optional[set[str]] = None
|
||||
) -> list[str]:
|
||||
if processed_refs is None:
|
||||
processed_refs = set()
|
||||
|
||||
|
@ -16,7 +16,6 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -45,7 +44,7 @@ class ChevronError(SyntaxError):
|
||||
#
|
||||
|
||||
|
||||
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
|
||||
def grab_literal(template: str, l_del: str) -> tuple[str, str]:
|
||||
"""Parse a literal from the template.
|
||||
|
||||
Args:
|
||||
@ -124,7 +123,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
|
||||
def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]:
|
||||
"""Parse a tag from a template.
|
||||
|
||||
Args:
|
||||
@ -201,7 +200,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
|
||||
|
||||
def tokenize(
|
||||
template: str, def_ldel: str = "{{", def_rdel: str = "}}"
|
||||
) -> Iterator[Tuple[str, str]]:
|
||||
) -> Iterator[tuple[str, str]]:
|
||||
"""Tokenize a mustache template.
|
||||
|
||||
Tokenizes a mustache template in a generator fashion,
|
||||
@ -427,13 +426,13 @@ def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
|
||||
#
|
||||
# The main rendering function
|
||||
#
|
||||
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
|
||||
g_token_cache: dict[str, list[tuple[str, str]]] = {}
|
||||
|
||||
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
|
||||
|
||||
|
||||
def render(
|
||||
template: Union[str, List[Tuple[str, str]]] = "",
|
||||
template: Union[str, list[tuple[str, str]]] = "",
|
||||
data: Mapping[str, Any] = EMPTY_DICT,
|
||||
partials_dict: Mapping[str, str] = EMPTY_DICT,
|
||||
padding: str = "",
|
||||
@ -490,7 +489,7 @@ def render(
|
||||
if isinstance(template, Sequence) and not isinstance(template, str):
|
||||
# Then we don't need to tokenize it
|
||||
# But it does need to be a generator
|
||||
tokens: Iterator[Tuple[str, str]] = (token for token in template)
|
||||
tokens: Iterator[tuple[str, str]] = (token for token in template)
|
||||
else:
|
||||
if template in g_token_cache:
|
||||
tokens = (token for token in g_token_cache[template])
|
||||
@ -561,7 +560,7 @@ def render(
|
||||
if callable(scope):
|
||||
# Generate template text from tags
|
||||
text = ""
|
||||
tags: List[Tuple[str, str]] = []
|
||||
tags: list[tuple[str, str]] = []
|
||||
for token in tokens:
|
||||
if token == ("end", key):
|
||||
break
|
||||
|
@ -11,7 +11,6 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -71,7 +70,7 @@ else:
|
||||
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
|
||||
|
||||
|
||||
def is_pydantic_v1_subclass(cls: Type) -> bool:
|
||||
def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
return True
|
||||
@ -83,14 +82,14 @@ def is_pydantic_v1_subclass(cls: Type) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_pydantic_v2_subclass(cls: Type) -> bool:
|
||||
def is_pydantic_v2_subclass(cls: type) -> bool:
|
||||
"""Check if the installed Pydantic version is 1.x-like."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel)
|
||||
|
||||
|
||||
def is_basemodel_subclass(cls: Type) -> bool:
|
||||
def is_basemodel_subclass(cls: type) -> bool:
|
||||
"""Check if the given class is a subclass of Pydantic BaseModel.
|
||||
|
||||
Check if the given class is a subclass of any of the following:
|
||||
@ -166,7 +165,7 @@ def pre_init(func: Callable) -> Any:
|
||||
|
||||
@root_validator(pre=True)
|
||||
@wraps(func)
|
||||
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Decorator to run a function before model initialization.
|
||||
|
||||
Args:
|
||||
@ -218,12 +217,12 @@ class _IgnoreUnserializable(GenerateJsonSchema):
|
||||
|
||||
def _create_subset_model_v1(
|
||||
name: str,
|
||||
model: Type[BaseModel],
|
||||
model: type[BaseModel],
|
||||
field_names: list,
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import create_model
|
||||
@ -256,12 +255,12 @@ def _create_subset_model_v1(
|
||||
|
||||
def _create_subset_model_v2(
|
||||
name: str,
|
||||
model: Type[pydantic.BaseModel],
|
||||
field_names: List[str],
|
||||
model: type[pydantic.BaseModel],
|
||||
field_names: list[str],
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[pydantic.BaseModel]:
|
||||
) -> type[pydantic.BaseModel]:
|
||||
"""Create a pydantic model with a subset of the model fields."""
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -299,11 +298,11 @@ def _create_subset_model_v2(
|
||||
def _create_subset_model(
|
||||
name: str,
|
||||
model: TypeBaseModel,
|
||||
field_names: List[str],
|
||||
field_names: list[str],
|
||||
*,
|
||||
descriptions: Optional[dict] = None,
|
||||
fn_description: Optional[str] = None,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create subset model using the same pydantic version as the input model."""
|
||||
if PYDANTIC_MAJOR_VERSION == 1:
|
||||
return _create_subset_model_v1(
|
||||
@ -344,25 +343,25 @@ if PYDANTIC_MAJOR_VERSION == 2:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ...
|
||||
def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ...
|
||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ...
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ...
|
||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
Type[BaseModelV2],
|
||||
Type[BaseModelV1],
|
||||
type[BaseModelV2],
|
||||
type[BaseModelV1],
|
||||
],
|
||||
) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]:
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields # type: ignore
|
||||
@ -375,8 +374,8 @@ elif PYDANTIC_MAJOR_VERSION == 1:
|
||||
from pydantic import BaseModel as BaseModelV1_
|
||||
|
||||
def get_fields( # type: ignore[no-redef]
|
||||
model: Union[Type[BaseModelV1_], BaseModelV1_],
|
||||
) -> Dict[str, FieldInfoV1]:
|
||||
model: Union[type[BaseModelV1_], BaseModelV1_],
|
||||
) -> dict[str, FieldInfoV1]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
return model.__fields__ # type: ignore
|
||||
else:
|
||||
@ -394,14 +393,14 @@ def _create_root_model(
|
||||
type_: Any,
|
||||
module_name: Optional[str] = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: Type[BaseModel],
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
@ -410,12 +409,12 @@ def _create_root_model(
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: Type[BaseModel],
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
@ -452,7 +451,7 @@ def _create_root_model_cached(
|
||||
*,
|
||||
module_name: Optional[str] = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
return _create_root_model(
|
||||
model_name, type_, default_=default_, module_name=module_name
|
||||
)
|
||||
@ -462,7 +461,7 @@ def _create_root_model_cached(
|
||||
def _create_model_cached(
|
||||
__model_name: str,
|
||||
**field_definitions: Any,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
return _create_model_base(
|
||||
__model_name,
|
||||
__config__=_SchemaConfig,
|
||||
@ -474,7 +473,7 @@ def create_model(
|
||||
__model_name: str,
|
||||
__module_name: Optional[str] = None,
|
||||
**field_definitions: Any,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
|
||||
Please use create_model_v2 instead of this function.
|
||||
@ -513,7 +512,7 @@ def create_model(
|
||||
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
||||
|
||||
|
||||
def _remap_field_definitions(field_definitions: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -547,9 +546,9 @@ def create_model_v2(
|
||||
model_name: str,
|
||||
*,
|
||||
module_name: Optional[str] = None,
|
||||
field_definitions: Optional[Dict[str, Any]] = None,
|
||||
field_definitions: Optional[dict[str, Any]] = None,
|
||||
root: Optional[Any] = None,
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
|
||||
Attention:
|
||||
|
@ -32,13 +32,9 @@ from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
@ -66,13 +62,13 @@ class VectorStore(ABC):
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
# One of the kwargs should be `ids` which is a list of ids
|
||||
# associated with the texts.
|
||||
# This is not yet enforced in the type signature for backwards compatibility
|
||||
# with existing implementations.
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
@ -124,7 +120,7 @@ class VectorStore(ABC):
|
||||
)
|
||||
return None
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete by vector ID or other criteria.
|
||||
|
||||
Args:
|
||||
@ -138,7 +134,7 @@ class VectorStore(ABC):
|
||||
|
||||
raise NotImplementedError("delete method must be implemented by subclass.")
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
"""Get documents by their IDs.
|
||||
|
||||
The returned documents are expected to have the ID field set to the ID of the
|
||||
@ -167,7 +163,7 @@ class VectorStore(ABC):
|
||||
)
|
||||
|
||||
# Implementations should override this method to provide an async native version.
|
||||
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
"""Async get documents by their IDs.
|
||||
|
||||
The returned documents are expected to have the ID field set to the ID of the
|
||||
@ -194,7 +190,7 @@ class VectorStore(ABC):
|
||||
return await run_in_executor(None, self.get_by_ids, ids)
|
||||
|
||||
async def adelete(
|
||||
self, ids: Optional[List[str]] = None, **kwargs: Any
|
||||
self, ids: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> Optional[bool]:
|
||||
"""Async delete by vector ID or other criteria.
|
||||
|
||||
@ -211,9 +207,9 @@ class VectorStore(ABC):
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Async run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
@ -254,7 +250,7 @@ class VectorStore(ABC):
|
||||
return await self.aadd_documents(docs, **kwargs)
|
||||
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
|
||||
"""Add or update documents in the vectorstore.
|
||||
|
||||
Args:
|
||||
@ -287,8 +283,8 @@ class VectorStore(ABC):
|
||||
)
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], **kwargs: Any
|
||||
) -> List[str]:
|
||||
self, documents: list[Document], **kwargs: Any
|
||||
) -> list[str]:
|
||||
"""Async run more documents through the embeddings and add to
|
||||
the vectorstore.
|
||||
|
||||
@ -318,7 +314,7 @@ class VectorStore(ABC):
|
||||
|
||||
return await run_in_executor(None, self.add_documents, documents, **kwargs)
|
||||
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
|
||||
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs most similar to query using a specified search type.
|
||||
|
||||
Args:
|
||||
@ -352,7 +348,7 @@ class VectorStore(ABC):
|
||||
|
||||
async def asearch(
|
||||
self, query: str, search_type: str, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Async return docs most similar to query using a specified search type.
|
||||
|
||||
Args:
|
||||
@ -386,7 +382,7 @@ class VectorStore(ABC):
|
||||
@abstractmethod
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
@ -442,7 +438,7 @@ class VectorStore(ABC):
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Run similarity search with distance.
|
||||
|
||||
Args:
|
||||
@ -456,7 +452,7 @@ class VectorStore(ABC):
|
||||
|
||||
async def asimilarity_search_with_score(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Async run similarity search with distance.
|
||||
|
||||
Args:
|
||||
@ -479,7 +475,7 @@ class VectorStore(ABC):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""
|
||||
Default similarity search with relevance scores. Modify if necessary
|
||||
in subclass.
|
||||
@ -506,7 +502,7 @@ class VectorStore(ABC):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""
|
||||
Default similarity search with relevance scores. Modify if necessary
|
||||
in subclass.
|
||||
@ -533,7 +529,7 @@ class VectorStore(ABC):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
@ -581,7 +577,7 @@ class VectorStore(ABC):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
"""Async return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
@ -626,7 +622,7 @@ class VectorStore(ABC):
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Async return docs most similar to query.
|
||||
|
||||
Args:
|
||||
@ -644,8 +640,8 @@ class VectorStore(ABC):
|
||||
return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
self, embedding: list[float], k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
@ -659,8 +655,8 @@ class VectorStore(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
async def asimilarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
self, embedding: list[float], k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
"""Async return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
@ -686,7 +682,7 @@ class VectorStore(ABC):
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
@ -715,7 +711,7 @@ class VectorStore(ABC):
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Async return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
@ -750,12 +746,12 @@ class VectorStore(ABC):
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
@ -779,12 +775,12 @@ class VectorStore(ABC):
|
||||
|
||||
async def amax_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
"""Async return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
@ -816,8 +812,8 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls: Type[VST],
|
||||
documents: List[Document],
|
||||
cls: type[VST],
|
||||
documents: list[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
@ -837,8 +833,8 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
async def afrom_documents(
|
||||
cls: Type[VST],
|
||||
documents: List[Document],
|
||||
cls: type[VST],
|
||||
documents: list[Document],
|
||||
embedding: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
@ -859,10 +855,10 @@ class VectorStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_texts(
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
cls: type[VST],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from texts and embeddings.
|
||||
@ -880,10 +876,10 @@ class VectorStore(ABC):
|
||||
|
||||
@classmethod
|
||||
async def afrom_texts(
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
cls: type[VST],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
"""Async return VectorStore initialized from texts and embeddings.
|
||||
@ -902,7 +898,7 @@ class VectorStore(ABC):
|
||||
None, cls.from_texts, texts, embedding, metadatas, **kwargs
|
||||
)
|
||||
|
||||
def _get_retriever_tags(self) -> List[str]:
|
||||
def _get_retriever_tags(self) -> list[str]:
|
||||
"""Get tags for retriever."""
|
||||
tags = [self.__class__.__name__]
|
||||
if self.embeddings:
|
||||
@ -991,7 +987,7 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_search_type(cls, values: Dict) -> Any:
|
||||
def validate_search_type(cls, values: dict) -> Any:
|
||||
"""Validate search type.
|
||||
|
||||
Args:
|
||||
@ -1040,7 +1036,7 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
@ -1060,7 +1056,7 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
@ -1080,7 +1076,7 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
|
||||
"""Add documents to the vectorstore.
|
||||
|
||||
Args:
|
||||
@ -1093,8 +1089,8 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
return self.vectorstore.add_documents(documents, **kwargs)
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], **kwargs: Any
|
||||
) -> List[str]:
|
||||
self, documents: list[Document], **kwargs: Any
|
||||
) -> list[str]:
|
||||
"""Async add documents to the vectorstore.
|
||||
|
||||
Args:
|
||||
|
@ -7,12 +7,9 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@ -153,7 +150,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
"""
|
||||
# TODO: would be nice to change to
|
||||
# Dict[str, Document] at some point (will be a breaking change)
|
||||
self.store: Dict[str, Dict[str, Any]] = {}
|
||||
self.store: dict[str, dict[str, Any]] = {}
|
||||
self.embedding = embedding
|
||||
|
||||
@property
|
||||
@ -170,10 +167,10 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
documents: List[Document],
|
||||
ids: Optional[List[str]] = None,
|
||||
documents: list[Document],
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Add documents to the store."""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
vectors = self.embedding.embed_documents(texts)
|
||||
@ -204,8 +201,8 @@ class InMemoryVectorStore(VectorStore):
|
||||
return ids_
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> List[str]:
|
||||
self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> list[str]:
|
||||
"""Add documents to the store."""
|
||||
texts = [doc.page_content for doc in documents]
|
||||
vectors = await self.embedding.aembed_documents(texts)
|
||||
@ -219,7 +216,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
id_iterator: Iterator[Optional[str]] = (
|
||||
iter(ids) if ids else iter(doc.id for doc in documents)
|
||||
)
|
||||
ids_: List[str] = []
|
||||
ids_: list[str] = []
|
||||
|
||||
for doc, vector in zip(documents, vectors):
|
||||
doc_id = next(id_iterator)
|
||||
@ -234,7 +231,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
"""Get documents by their ids.
|
||||
|
||||
Args:
|
||||
@ -313,7 +310,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
"failed": [],
|
||||
}
|
||||
|
||||
async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
"""Async get documents by their ids.
|
||||
|
||||
Args:
|
||||
@ -326,11 +323,11 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
def _similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Callable[[Document], bool]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float, List[float]]]:
|
||||
) -> list[tuple[Document, float, list[float]]]:
|
||||
result = []
|
||||
for doc in self.store.values():
|
||||
vector = doc["vector"]
|
||||
@ -351,11 +348,11 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Callable[[Document], bool]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
return [
|
||||
(doc, similarity)
|
||||
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
|
||||
@ -368,7 +365,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
embedding = self.embedding.embed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
@ -379,7 +376,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
async def asimilarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
) -> list[tuple[Document, float]]:
|
||||
embedding = await self.embedding.aembed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
@ -390,10 +387,10 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
@ -402,18 +399,18 @@ class InMemoryVectorStore(VectorStore):
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
async def asimilarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
self, embedding: list[float], k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
return self.similarity_search_by_vector(embedding, k, **kwargs)
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]
|
||||
|
||||
async def asimilarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
return [
|
||||
doc
|
||||
for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
|
||||
@ -421,12 +418,12 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
prefetch_hits = self._similarity_search_with_score_by_vector(
|
||||
embedding=embedding,
|
||||
k=fetch_k,
|
||||
@ -456,7 +453,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding_vector,
|
||||
@ -473,7 +470,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
embedding_vector = await self.embedding.aembed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding_vector,
|
||||
@ -486,9 +483,9 @@ class InMemoryVectorStore(VectorStore):
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> InMemoryVectorStore:
|
||||
store = cls(
|
||||
@ -500,9 +497,9 @@ class InMemoryVectorStore(VectorStore):
|
||||
@classmethod
|
||||
async def afrom_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> InMemoryVectorStore:
|
||||
store = cls(
|
||||
|
@ -76,7 +76,7 @@ def maximal_marginal_relevance(
|
||||
embedding_list: list,
|
||||
lambda_mult: float = 0.5,
|
||||
k: int = 4,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
"""Calculate maximal marginal relevance.
|
||||
|
||||
Args:
|
||||
|
8
libs/core/poetry.lock
generated
8
libs/core/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@ -1186,7 +1186,7 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
httpx = "^0.27.0"
|
||||
langchain-core = ">=0.3.0.dev1"
|
||||
langchain-core = "^0.3.0"
|
||||
pytest = ">=7,<9"
|
||||
syrupy = "^4"
|
||||
|
||||
@ -1196,7 +1196,7 @@ url = "../standard-tests"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-text-splitters"
|
||||
version = "0.3.0.dev1"
|
||||
version = "0.3.0"
|
||||
description = "LangChain text splitting utilities"
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
@ -1204,7 +1204,7 @@ files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
langchain-core = "^0.3.0.dev1"
|
||||
langchain-core = "^0.3.0"
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
|
@ -41,8 +41,8 @@ python = ">=3.12.4"
|
||||
[tool.poetry.extras]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "B", "E", "F", "I", "T201", "UP",]
|
||||
ignore = [ "UP006", "UP007",]
|
||||
select = ["B", "E", "F", "I", "T201", "UP"]
|
||||
ignore = ["UP007"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [ "tests/*",]
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -26,7 +26,7 @@ class FakeAsyncTracer(AsyncBaseTracer):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Run] = []
|
||||
self.runs: list[Run] = []
|
||||
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
self.runs.append(run)
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
@ -30,7 +30,7 @@ class FakeTracer(BaseTracer):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Run] = []
|
||||
self.runs: list[Run] = []
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
|
@ -7,7 +7,7 @@ the relevant methods.
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
||||
from typing import Any, Iterable, Optional, Sequence
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
@ -18,19 +18,19 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
"""A vectorstore that only implements add texts."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: Dict[str, Document] = {}
|
||||
self.store: dict[str, Document] = {}
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
# One of the kwargs should be `ids` which is a list of ids
|
||||
# associated with the texts.
|
||||
# This is not yet enforced in the type signature for backwards compatibility
|
||||
# with existing implementations.
|
||||
ids: Optional[List[str]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
if not isinstance(texts, list):
|
||||
texts = list(texts)
|
||||
ids_iter = iter(ids or [])
|
||||
@ -46,14 +46,14 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
ids_.append(id_)
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
|
||||
def from_texts( # type: ignore
|
||||
cls,
|
||||
texts: List[str],
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> CustomAddTextsVectorstore:
|
||||
vectorstore = CustomAddTextsVectorstore()
|
||||
@ -62,7 +62,7 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user