mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
chore(langchain): add ruff rule UP007 in langchain_v1
(#32811)
Done by autofix
This commit is contained in:
committed by
GitHub
parent
54c2419a4e
commit
fe6c415c9f
@@ -1,13 +1,12 @@
|
||||
"""Lazy import utilities."""
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Union
|
||||
|
||||
|
||||
def import_attr(
|
||||
attr_name: str,
|
||||
module_name: Union[str, None],
|
||||
package: Union[str, None],
|
||||
module_name: str | None,
|
||||
package: str | None,
|
||||
) -> object:
|
||||
"""Import an attribute from a module located in a package.
|
||||
|
||||
|
@@ -12,7 +12,7 @@ particularly for summarization chains and other document processing workflows.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
@@ -24,11 +24,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def resolve_prompt(
|
||||
prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
],
|
||||
prompt: str | None | Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
default_user_content: str,
|
||||
@@ -89,12 +85,10 @@ def resolve_prompt(
|
||||
|
||||
|
||||
async def aresolve_prompt(
|
||||
prompt: Union[
|
||||
str,
|
||||
None,
|
||||
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
|
||||
Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
|
||||
],
|
||||
prompt: str
|
||||
| None
|
||||
| Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]]
|
||||
| Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
default_user_content: str,
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from langgraph.graph._node import StateNode
|
||||
from pydantic import BaseModel
|
||||
@@ -44,7 +44,7 @@ class DataclassLike(Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
|
||||
StateLike: TypeAlias = Union[TypedDictLikeV1, TypedDictLikeV2, DataclassLike, BaseModel]
|
||||
StateLike: TypeAlias = TypedDictLikeV1 | TypedDictLikeV2 | DataclassLike | BaseModel
|
||||
"""Type alias for state-like types.
|
||||
|
||||
It can either be a ``TypedDict``, ``dataclass``, or Pydantic ``BaseModel``.
|
||||
@@ -58,7 +58,7 @@ It can either be a ``TypedDict``, ``dataclass``, or Pydantic ``BaseModel``.
|
||||
StateT = TypeVar("StateT", bound=StateLike)
|
||||
"""Type variable used to represent the state in a graph."""
|
||||
|
||||
ContextT = TypeVar("ContextT", bound=Union[StateLike, None])
|
||||
ContextT = TypeVar("ContextT", bound=StateLike | None)
|
||||
"""Type variable for context types."""
|
||||
|
||||
|
||||
|
@@ -3,11 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar, Union
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
SyncOrAsync = Callable[P, Union[R, Awaitable[R]]]
|
||||
SyncOrAsync = Callable[P, R | Awaitable[R]]
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Interrupt types to use with agent inbox like setups."""
|
||||
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -89,4 +89,4 @@ class HumanResponse(TypedDict):
|
||||
"""
|
||||
|
||||
type: Literal["accept", "ignore", "response", "edit"]
|
||||
args: Union[None, str, ActionRequest]
|
||||
args: None | str | ActionRequest
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
@@ -59,7 +59,7 @@ def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, A
|
||||
return {k: v for k, v in state.items() if k in schema_fields}
|
||||
|
||||
|
||||
def _supports_native_structured_output(model: Union[str, BaseChatModel]) -> bool:
|
||||
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
|
||||
"""Check if a model supports native structured output."""
|
||||
model_name: str | None = None
|
||||
if isinstance(model, str):
|
||||
|
@@ -11,7 +11,6 @@ from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
Union,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
@@ -104,12 +103,12 @@ class AgentStateWithStructuredResponsePydantic(AgentStatePydantic, Generic[Struc
|
||||
|
||||
PROMPT_RUNNABLE_NAME = "Prompt"
|
||||
|
||||
Prompt = Union[
|
||||
SystemMessage,
|
||||
str,
|
||||
Callable[[StateT], LanguageModelInput],
|
||||
Runnable[StateT, LanguageModelInput],
|
||||
]
|
||||
Prompt = (
|
||||
SystemMessage
|
||||
| str
|
||||
| Callable[[StateT], LanguageModelInput]
|
||||
| Runnable[StateT, LanguageModelInput]
|
||||
)
|
||||
|
||||
|
||||
def _get_state_value(state: StateT, key: str, default: Any = None) -> Any:
|
||||
@@ -189,12 +188,8 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[
|
||||
str,
|
||||
BaseChatModel,
|
||||
SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
||||
],
|
||||
tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
|
||||
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode,
|
||||
*,
|
||||
prompt: Prompt | None = None,
|
||||
response_format: ResponseFormat[StructuredResponseT] | None = None,
|
||||
@@ -691,10 +686,10 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
||||
return CallModelInputSchema
|
||||
return self._final_state_schema
|
||||
|
||||
def create_model_router(self) -> Callable[[StateT], Union[str, list[Send]]]:
|
||||
def create_model_router(self) -> Callable[[StateT], str | list[Send]]:
|
||||
"""Create routing function for model node conditional edges."""
|
||||
|
||||
def should_continue(state: StateT) -> Union[str, list[Send]]:
|
||||
def should_continue(state: StateT) -> str | list[Send]:
|
||||
messages = _get_state_value(state, "messages")
|
||||
last_message = messages[-1]
|
||||
|
||||
@@ -731,10 +726,10 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
||||
|
||||
def create_post_model_hook_router(
|
||||
self,
|
||||
) -> Callable[[StateT], Union[str, list[Send]]]:
|
||||
) -> Callable[[StateT], str | list[Send]]:
|
||||
"""Create a routing function for post_model_hook node conditional edges."""
|
||||
|
||||
def post_model_hook_router(state: StateT) -> Union[str, list[Send]]:
|
||||
def post_model_hook_router(state: StateT) -> str | list[Send]:
|
||||
messages = _get_state_value(state, "messages")
|
||||
|
||||
# Check if the last message is a ToolMessage from a structured tool.
|
||||
@@ -882,7 +877,7 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
||||
|
||||
|
||||
def _supports_native_structured_output(
|
||||
model: Union[str, BaseChatModel, SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel]],
|
||||
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
||||
) -> bool:
|
||||
"""Check if a model supports native structured output.
|
||||
|
||||
@@ -903,20 +898,14 @@ def _supports_native_structured_output(
|
||||
|
||||
|
||||
def create_agent( # noqa: D417
|
||||
model: Union[
|
||||
str,
|
||||
BaseChatModel,
|
||||
SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
||||
],
|
||||
tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
|
||||
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode,
|
||||
*,
|
||||
middleware: Sequence[AgentMiddleware] = (),
|
||||
prompt: Prompt | None = None,
|
||||
response_format: Union[
|
||||
ToolStrategy[StructuredResponseT],
|
||||
ProviderStrategy[StructuredResponseT],
|
||||
type[StructuredResponseT],
|
||||
]
|
||||
response_format: ToolStrategy[StructuredResponseT]
|
||||
| ProviderStrategy[StructuredResponseT]
|
||||
| type[StructuredResponseT]
|
||||
| None = None,
|
||||
pre_model_hook: RunnableLike | None = None,
|
||||
post_model_hook: RunnableLike | None = None,
|
||||
@@ -1175,7 +1164,7 @@ def create_agent( # noqa: D417
|
||||
model=model,
|
||||
tools=tools,
|
||||
prompt=prompt,
|
||||
response_format=cast("Union[ResponseFormat[StructuredResponseT], None]", response_format),
|
||||
response_format=cast("ResponseFormat[StructuredResponseT] | None", response_format),
|
||||
pre_model_hook=pre_model_hook,
|
||||
post_model_hook=post_model_hook,
|
||||
state_schema=state_schema,
|
||||
|
@@ -67,7 +67,7 @@ class StructuredOutputValidationError(StructuredOutputError):
|
||||
|
||||
|
||||
def _parse_with_schema(
|
||||
schema: Union[type[SchemaT], dict], schema_kind: SchemaKind, data: dict[str, Any]
|
||||
schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
|
||||
) -> Any:
|
||||
"""Parse data using for any supported schema type.
|
||||
|
||||
@@ -180,13 +180,9 @@ class ToolStrategy(Generic[SchemaT]):
|
||||
tool_message_content: str | None
|
||||
"""The content of the tool message to be returned when the model calls an artificial structured output tool."""
|
||||
|
||||
handle_errors: Union[
|
||||
bool,
|
||||
str,
|
||||
type[Exception],
|
||||
tuple[type[Exception], ...],
|
||||
Callable[[Exception], str],
|
||||
]
|
||||
handle_errors: (
|
||||
bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
|
||||
)
|
||||
"""Error handling strategy for structured output via ToolStrategy. Default is True.
|
||||
|
||||
- True: Catch all errors with default error template
|
||||
@@ -202,13 +198,11 @@ class ToolStrategy(Generic[SchemaT]):
|
||||
schema: type[SchemaT],
|
||||
*,
|
||||
tool_message_content: str | None = None,
|
||||
handle_errors: Union[
|
||||
bool,
|
||||
str,
|
||||
type[Exception],
|
||||
tuple[type[Exception], ...],
|
||||
Callable[[Exception], str],
|
||||
] = True,
|
||||
handle_errors: bool
|
||||
| str
|
||||
| type[Exception]
|
||||
| tuple[type[Exception], ...]
|
||||
| Callable[[Exception], str] = True,
|
||||
) -> None:
|
||||
"""Initialize ToolStrategy with schemas, tool message content, and error handling strategy."""
|
||||
self.schema = schema
|
||||
@@ -400,4 +394,4 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
||||
return str(content)
|
||||
|
||||
|
||||
ResponseFormat = Union[ToolStrategy[SchemaT], ProviderStrategy[SchemaT]]
|
||||
ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT]
|
||||
|
@@ -91,7 +91,7 @@ TOOL_EXECUTION_ERROR_TEMPLATE = "Error executing tool '{tool_name}' with kwargs
|
||||
TOOL_INVOCATION_ERROR_TEMPLATE = "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n {error}\n Please fix the error and try again."
|
||||
|
||||
|
||||
def msg_content_output(output: Any) -> Union[str, list[dict]]:
|
||||
def msg_content_output(output: Any) -> str | list[dict]:
|
||||
"""Convert tool output to valid message content format.
|
||||
|
||||
LangChain ToolMessages accept either string content or a list of content blocks.
|
||||
@@ -161,13 +161,7 @@ def _default_handle_tool_errors(e: Exception) -> str:
|
||||
def _handle_tool_error(
|
||||
e: Exception,
|
||||
*,
|
||||
flag: Union[
|
||||
bool,
|
||||
str,
|
||||
Callable[..., str],
|
||||
type[Exception],
|
||||
tuple[type[Exception], ...],
|
||||
],
|
||||
flag: bool | str | Callable[..., str] | type[Exception] | tuple[type[Exception], ...],
|
||||
) -> str:
|
||||
"""Generate error message content based on exception handling configuration.
|
||||
|
||||
@@ -380,13 +374,15 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: Sequence[Union[BaseTool, Callable]],
|
||||
tools: Sequence[BaseTool | Callable],
|
||||
*,
|
||||
name: str = "tools",
|
||||
tags: list[str] | None = None,
|
||||
handle_tool_errors: Union[
|
||||
bool, str, Callable[..., str], type[Exception], tuple[type[Exception], ...]
|
||||
] = _default_handle_tool_errors,
|
||||
handle_tool_errors: bool
|
||||
| str
|
||||
| Callable[..., str]
|
||||
| type[Exception]
|
||||
| tuple[type[Exception], ...] = _default_handle_tool_errors,
|
||||
messages_key: str = "messages",
|
||||
) -> None:
|
||||
"""Initialize the ToolNode with the provided tools and configuration.
|
||||
@@ -420,11 +416,7 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
def _func(
|
||||
self,
|
||||
input: Union[
|
||||
list[AnyMessage],
|
||||
dict[str, Any],
|
||||
BaseModel,
|
||||
],
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
store: Optional[BaseStore], # noqa: UP045
|
||||
@@ -439,11 +431,7 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
async def _afunc(
|
||||
self,
|
||||
input: Union[
|
||||
list[AnyMessage],
|
||||
dict[str, Any],
|
||||
BaseModel,
|
||||
],
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
store: Optional[BaseStore], # noqa: UP045
|
||||
@@ -457,9 +445,9 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
def _combine_tool_outputs(
|
||||
self,
|
||||
outputs: list[Union[ToolMessage, Command]],
|
||||
outputs: list[ToolMessage | Command],
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
) -> list[Union[Command, list[ToolMessage], dict[str, list[ToolMessage]]]]:
|
||||
) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
|
||||
# preserve existing behavior for non-command tool outputs for backwards
|
||||
# compatibility
|
||||
if not any(isinstance(output, Command) for output in outputs):
|
||||
@@ -502,7 +490,7 @@ class ToolNode(RunnableCallable):
|
||||
call: ToolCall,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
) -> Union[ToolMessage, Command]:
|
||||
) -> ToolMessage | Command:
|
||||
"""Run a single tool call synchronously."""
|
||||
if invalid_tool_message := self._validate_tool_call(call):
|
||||
return invalid_tool_message
|
||||
@@ -556,7 +544,7 @@ class ToolNode(RunnableCallable):
|
||||
if isinstance(response, Command):
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("Union[str, list]", msg_content_output(response.content))
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
||||
raise TypeError(msg)
|
||||
@@ -566,7 +554,7 @@ class ToolNode(RunnableCallable):
|
||||
call: ToolCall,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
) -> Union[ToolMessage, Command]:
|
||||
) -> ToolMessage | Command:
|
||||
"""Run a single tool call asynchronously."""
|
||||
if invalid_tool_message := self._validate_tool_call(call):
|
||||
return invalid_tool_message
|
||||
@@ -621,18 +609,14 @@ class ToolNode(RunnableCallable):
|
||||
if isinstance(response, Command):
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("Union[str, list]", msg_content_output(response.content))
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
input: Union[
|
||||
list[AnyMessage],
|
||||
dict[str, Any],
|
||||
BaseModel,
|
||||
],
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
store: BaseStore | None,
|
||||
) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
|
||||
input_type: Literal["list", "dict", "tool_calls"]
|
||||
@@ -679,11 +663,7 @@ class ToolNode(RunnableCallable):
|
||||
def _inject_state(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
input: Union[
|
||||
list[AnyMessage],
|
||||
dict[str, Any],
|
||||
BaseModel,
|
||||
],
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
) -> ToolCall:
|
||||
state_args = self._tool_to_state_args[tool_call["name"]]
|
||||
if state_args and isinstance(input, list):
|
||||
@@ -740,11 +720,7 @@ class ToolNode(RunnableCallable):
|
||||
def inject_tool_args(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
input: Union[
|
||||
list[AnyMessage],
|
||||
dict[str, Any],
|
||||
BaseModel,
|
||||
],
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
store: BaseStore | None,
|
||||
) -> ToolCall:
|
||||
"""Inject graph state and store into tool call arguments.
|
||||
@@ -853,7 +829,7 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
|
||||
def tools_condition(
|
||||
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
|
||||
state: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
messages_key: str = "messages",
|
||||
) -> Literal["tools", "__end__"]:
|
||||
"""Conditional routing function for tool-calling workflows.
|
||||
@@ -1086,7 +1062,7 @@ class InjectedStore(InjectedToolArg):
|
||||
"""
|
||||
|
||||
|
||||
def _is_injection(type_arg: Any, injection_type: type[Union[InjectedState, InjectedStore]]) -> bool:
|
||||
def _is_injection(type_arg: Any, injection_type: type[InjectedState | InjectedStore]) -> bool:
|
||||
"""Check if a type argument represents an injection annotation.
|
||||
|
||||
This utility function determines whether a type annotation indicates that
|
||||
|
@@ -9,7 +9,6 @@ from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
@@ -55,7 +54,7 @@ def init_chat_model(
|
||||
model: str | None = None,
|
||||
*,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
|
||||
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
@@ -68,10 +67,10 @@ def init_chat_model(
|
||||
model: str | None = None,
|
||||
*,
|
||||
model_provider: str | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] | None = None,
|
||||
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
|
||||
config_prefix: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[BaseChatModel, _ConfigurableModel]:
|
||||
) -> BaseChatModel | _ConfigurableModel:
|
||||
"""Initialize a ChatModel from the model name and provider.
|
||||
|
||||
**Note:** Must have the integration package corresponding to the model provider
|
||||
@@ -531,12 +530,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
self,
|
||||
*,
|
||||
default_config: dict | None = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
|
||||
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
|
||||
config_prefix: str = "",
|
||||
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
|
||||
) -> None:
|
||||
self._default_config: dict = default_config or {}
|
||||
self._configurable_fields: Union[Literal["any"], list[str]] = (
|
||||
self._configurable_fields: Literal["any"] | list[str] = (
|
||||
configurable_fields if configurable_fields == "any" else list(configurable_fields)
|
||||
)
|
||||
self._config_prefix = (
|
||||
@@ -639,11 +638,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
# This is a version of LanguageModelInput which replaces the abstract
|
||||
# base class BaseMessage with a union of its subclasses, which makes
|
||||
# for a much better schema.
|
||||
return Union[
|
||||
str,
|
||||
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||
list[AnyMessage],
|
||||
]
|
||||
return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
@@ -685,7 +680,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def batch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
|
||||
config: RunnableConfig | list[RunnableConfig] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any | None,
|
||||
@@ -713,7 +708,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Union[RunnableConfig, list[RunnableConfig]] | None = None,
|
||||
config: RunnableConfig | list[RunnableConfig] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any | None,
|
||||
@@ -741,11 +736,11 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def batch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
|
||||
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[tuple[int, Union[Any, Exception]]]:
|
||||
) -> Iterator[tuple[int, Any | Exception]]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
@@ -770,7 +765,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Union[RunnableConfig, Sequence[RunnableConfig]] | None = None,
|
||||
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -868,7 +863,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
|
||||
) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
|
||||
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
|
||||
input,
|
||||
config=config,
|
||||
@@ -916,7 +911,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self.__getattr__("bind_tools")(tools, **kwargs)
|
||||
@@ -924,7 +919,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[dict, type[BaseModel]],
|
||||
schema: dict | type[BaseModel],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
return self.__getattr__("with_structured_output")(schema, **kwargs)
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
import functools
|
||||
from importlib import util
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables import Runnable
|
||||
@@ -126,7 +126,7 @@ def init_embeddings(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[Embeddings, Runnable[Any, list[float]]]:
|
||||
) -> Embeddings | Runnable[Any, list[float]]:
|
||||
"""Initialize an embeddings model from a model name and optional provider.
|
||||
|
||||
**Note:** Must have the integration package corresponding to the model provider
|
||||
|
@@ -13,7 +13,7 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Literal, Union, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
@@ -178,7 +178,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
Returns:
|
||||
A list of embeddings for the given texts.
|
||||
"""
|
||||
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
|
||||
vectors: list[list[float] | None] = self.document_embedding_store.mget(
|
||||
texts,
|
||||
)
|
||||
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
||||
@@ -210,7 +210,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
Returns:
|
||||
A list of embeddings for the given texts.
|
||||
"""
|
||||
vectors: list[Union[list[float], None]] = await self.document_embedding_store.amget(texts)
|
||||
vectors: list[list[float] | None] = await self.document_embedding_store.amget(texts)
|
||||
all_missing_indices: list[int] = [i for i, vector in enumerate(vectors) if vector is None]
|
||||
|
||||
# batch_iterate supports None batch_size which returns all elements at once
|
||||
@@ -285,11 +285,8 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
*,
|
||||
namespace: str = "",
|
||||
batch_size: int | None = None,
|
||||
query_embedding_cache: Union[bool, ByteStore] = False,
|
||||
key_encoder: Union[
|
||||
Callable[[str], str],
|
||||
Literal["sha1", "blake2b", "sha256", "sha512"],
|
||||
] = "sha1",
|
||||
query_embedding_cache: bool | ByteStore = False,
|
||||
key_encoder: Callable[[str], str] | Literal["sha1", "blake2b", "sha256", "sha512"] = "sha1",
|
||||
) -> CacheBackedEmbeddings:
|
||||
"""On-ramp that adds the necessary serialization and encoding to the store.
|
||||
|
||||
|
@@ -4,7 +4,6 @@ from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
@@ -106,7 +105,7 @@ class EncoderBackedStore(BaseStore[K, V]):
|
||||
self,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
) -> Union[Iterator[K], Iterator[str]]:
|
||||
) -> Iterator[K] | Iterator[str]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this does not return K, but str
|
||||
# it's for debugging purposes. Should fix this.
|
||||
@@ -116,7 +115,7 @@ class EncoderBackedStore(BaseStore[K, V]):
|
||||
self,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
|
||||
) -> AsyncIterator[K] | AsyncIterator[str]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this does not return K, but str
|
||||
# it's for debugging purposes. Should fix this.
|
||||
|
@@ -114,7 +114,6 @@ ignore = [
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
"SLF001", # Private member access
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
"PLC0415", # Imports should be at the top. Not always desirable
|
||||
"PLR0913", # Too many arguments in function definition
|
||||
"PLC0414", # Inconsistent with how type checkers expect to be notified of intentional re-exports
|
||||
|
Reference in New Issue
Block a user