From 31eeb50ce03a863270dbad6d973910f0055d17ac Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 8 Oct 2025 21:17:53 -0400 Subject: [PATCH] chore: drop UP045 (#33362) Python 3.9 EOL --- AGENTS.md | 2 +- CLAUDE.md | 2 +- libs/core/langchain_core/prompt_values.py | 4 +- .../langchain_classic/chains/moderation.py | 2 +- .../langchain_classic/chat_models/base.py | 62 +-- .../langchain_anthropic/_client_utils.py | 10 +- .../anthropic/langchain_anthropic/_compat.py | 4 +- .../langchain_anthropic/chat_models.py | 95 +++-- .../langchain_anthropic/experimental.py | 3 +- .../anthropic/langchain_anthropic/llms.py | 38 +- .../langchain_anthropic/output_parsers.py | 6 +- libs/partners/anthropic/pyproject.toml | 2 - .../integration_tests/test_chat_models.py | 26 +- .../anthropic/tests/unit_tests/_utils.py | 4 +- .../tests/unit_tests/test_chat_models.py | 10 +- libs/partners/anthropic/uv.lock | 4 +- .../chroma/langchain_chroma/__init__.py | 2 + .../chroma/langchain_chroma/vectorstores.py | 163 ++++---- libs/partners/chroma/pyproject.toml | 76 ++-- libs/partners/chroma/scripts/check_imports.py | 2 +- .../integration_tests/test_vectorstores.py | 16 +- .../chroma/tests/unit_tests/test_standard.py | 2 +- libs/partners/chroma/uv.lock | 2 +- .../deepseek/langchain_deepseek/__init__.py | 2 + .../langchain_deepseek/chat_models.py | 47 ++- libs/partners/deepseek/pyproject.toml | 58 +-- .../deepseek/scripts/check_imports.py | 4 +- libs/partners/deepseek/tests/__init__.py | 1 + .../tests/integration_tests/__init__.py | 1 + .../integration_tests/test_chat_models.py | 11 +- .../tests/integration_tests/test_compile.py | 2 + .../deepseek/tests/unit_tests/__init__.py | 1 + .../tests/unit_tests/test_chat_models.py | 17 +- libs/partners/deepseek/uv.lock | 6 +- libs/partners/exa/langchain_exa/__init__.py | 2 + libs/partners/exa/langchain_exa/retrievers.py | 30 +- libs/partners/exa/langchain_exa/tools.py | 70 ++-- libs/partners/exa/pyproject.toml | 12 +- libs/partners/exa/scripts/check_imports.py | 2 + libs/partners/exa/tests/__init__.py | 1 + .../exa/tests/integration_tests/__init__.py | 1 + .../tests/integration_tests/test_compile.py | 2 + .../test_find_similar_tool.py | 3 + .../tests/integration_tests/test_retriever.py | 5 + .../integration_tests/test_search_tool.py | 4 + .../partners/exa/tests/unit_tests/__init__.py | 1 + .../exa/tests/unit_tests/test_imports.py | 3 + .../exa/tests/unit_tests/test_standard.py | 2 + .../langchain_fireworks/chat_models.py | 58 ++- .../fireworks/langchain_fireworks/llms.py | 30 +- libs/partners/fireworks/pyproject.toml | 2 - .../integration_tests/test_chat_models.py | 4 +- .../groq/langchain_groq/chat_models.py | 72 ++-- libs/partners/groq/pyproject.toml | 2 - .../integration_tests/test_chat_models.py | 6 +- .../groq/tests/unit_tests/fake/callbacks.py | 6 +- .../chat_models/huggingface.py | 74 ++-- .../embeddings/huggingface.py | 4 +- .../embeddings/huggingface_endpoint.py | 14 +- .../llms/huggingface_endpoint.py | 52 +-- .../llms/huggingface_pipeline.py | 24 +- .../utils/import_utils.py | 3 +- libs/partners/huggingface/pyproject.toml | 2 - .../langchain_mistralai/chat_models.py | 60 ++- libs/partners/mistralai/pyproject.toml | 2 - .../integration_tests/test_chat_models.py | 4 +- .../nomic/langchain_nomic/embeddings.py | 26 +- libs/partners/nomic/pyproject.toml | 2 - .../ollama/langchain_ollama/_utils.py | 7 +- .../ollama/langchain_ollama/chat_models.py | 118 +++--- .../ollama/langchain_ollama/embeddings.py | 48 +-- libs/partners/ollama/langchain_ollama/llms.py | 92 ++--- libs/partners/ollama/pyproject.toml | 2 - .../chat_models/test_chat_models.py | 10 +- .../chat_models/_client_utils.py | 14 +- .../langchain_openai/chat_models/_compat.py | 4 +- .../langchain_openai/chat_models/azure.py | 40 +- .../langchain_openai/chat_models/base.py | 195 +++++---- .../langchain_openai/embeddings/azure.py | 18 +- .../langchain_openai/embeddings/base.py | 56 +-- .../openai/langchain_openai/llms/azure.py | 22 +- .../openai/langchain_openai/llms/base.py | 62 +-- libs/partners/openai/pyproject.toml | 2 - .../chat_models/test_azure.py | 14 +- .../chat_models/test_base.py | 26 +- .../chat_models/test_responses_api.py | 34 +- .../tests/unit_tests/chat_models/test_base.py | 34 +- .../chat_models/test_responses_stream.py | 4 +- .../openai/tests/unit_tests/fake/callbacks.py | 6 +- .../langchain_perplexity/chat_models.py | 33 +- libs/partners/perplexity/pyproject.toml | 1 - .../tests/unit_tests/test_chat_models.py | 10 +- .../prompty/langchain_prompty/core.py | 6 +- .../prompty/langchain_prompty/parsers.py | 3 +- .../prompty/langchain_prompty/utils.py | 6 +- libs/partners/prompty/pyproject.toml | 1 - .../tests/unit_tests/fake_callback_handler.py | 6 +- .../tests/unit_tests/fake_chat_model.py | 10 +- .../tests/unit_tests/fake_output_parser.py | 6 +- .../qdrant/langchain_qdrant/_utils.py | 8 +- .../langchain_qdrant/fastembed_sparse.py | 10 +- .../qdrant/langchain_qdrant/qdrant.py | 160 ++++---- .../qdrant/langchain_qdrant/vectorstores.py | 372 +++++++++--------- libs/partners/qdrant/pyproject.toml | 2 - .../async_api/test_add_texts.py | 3 +- .../async_api/test_from_texts.py | 13 +- .../async_api/test_max_marginal_relevance.py | 4 +- .../async_api/test_similarity_search.py | 18 +- .../qdrant_vector_store/test_add_texts.py | 3 +- .../qdrant_vector_store/test_from_texts.py | 3 +- .../tests/integration_tests/test_add_texts.py | 5 +- .../test_embedding_interface.py | 6 +- .../integration_tests/test_from_texts.py | 15 +- .../test_max_marginal_relevance.py | 4 +- .../test_similarity_search.py | 20 +- libs/partners/qdrant/uv.lock | 66 ++-- .../partners/xai/langchain_xai/chat_models.py | 47 ++- libs/partners/xai/pyproject.toml | 2 - .../test_chat_models_standard.py | 4 +- 119 files changed, 1423 insertions(+), 1504 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 3f5117c38d7..64fb862b18a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -149,7 +149,7 @@ def send_email(to: str, msg: str, *, priority: str = "normal") -> bool: Args: to: The email address of the recipient. msg: The message body to send. - priority: Email priority level (``'low'``, ``'normal'``, ``'high'``). + priority: Email priority level (`'low'`, ``'normal'``, `'high'`). Returns: True if email was sent successfully, False otherwise. diff --git a/CLAUDE.md b/CLAUDE.md index 3f5117c38d7..64fb862b18a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -149,7 +149,7 @@ def send_email(to: str, msg: str, *, priority: str = "normal") -> bool: Args: to: The email address of the recipient. msg: The message body to send. - priority: Email priority level (``'low'``, ``'normal'``, ``'high'``). + priority: Email priority level (`'low'`, ``'normal'``, `'high'`). Returns: True if email was sent successfully, False otherwise. diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 5f2ac7fdce5..f572ce0eaf4 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -113,8 +113,8 @@ class ImageURL(TypedDict, total=False): """Image URL.""" detail: Literal["auto", "low", "high"] - """Specifies the detail level of the image. Defaults to ``'auto'``. - Can be ``'auto'``, ``'low'``, or ``'high'``. + """Specifies the detail level of the image. Defaults to `'auto'`. + Can be `'auto'`, `'low'`, or `'high'`. This follows OpenAI's Chat Completion API's image URL format. diff --git a/libs/langchain/langchain_classic/chains/moderation.py b/libs/langchain/langchain_classic/chains/moderation.py index f009ac6c90a..b3f2c21ad6c 100644 --- a/libs/langchain/langchain_classic/chains/moderation.py +++ b/libs/langchain/langchain_classic/chains/moderation.py @@ -16,7 +16,7 @@ from langchain_classic.chains.base import Chain class OpenAIModerationChain(Chain): """Pass input through a moderation endpoint. - To use, you should have the ``openai`` python package installed, and the + To use, you should have the `openai` python package installed, and the environment variable ``OPENAI_API_KEY`` set with your API key. Any parameters that are valid to be passed to the openai.create call can be passed diff --git a/libs/langchain/langchain_classic/chat_models/base.py b/libs/langchain/langchain_classic/chat_models/base.py index d451d30f19f..d6e031437af 100644 --- a/libs/langchain/langchain_classic/chat_models/base.py +++ b/libs/langchain/langchain_classic/chat_models/base.py @@ -91,41 +91,41 @@ def init_chat_model( above). Supported model_provider values and the corresponding integration package are: - - ``openai`` -> `langchain-openai` - - ``anthropic`` -> `langchain-anthropic` - - ``azure_openai`` -> `langchain-openai` - - ``azure_ai`` -> ``langchain-azure-ai`` - - ``google_vertexai`` -> ``langchain-google-vertexai`` - - ``google_genai`` -> ``langchain-google-genai`` - - ``bedrock`` -> ``langchain-aws`` - - ``bedrock_converse`` -> ``langchain-aws`` - - ``cohere`` -> ``langchain-cohere`` - - ``fireworks`` -> ``langchain-fireworks`` - - ``together`` -> ``langchain-together`` - - ``mistralai`` -> ``langchain-mistralai`` - - ``huggingface`` -> ``langchain-huggingface`` - - ``groq`` -> ``langchain-groq`` - - ``ollama`` -> ``langchain-ollama`` - - ``google_anthropic_vertex`` -> ``langchain-google-vertexai`` - - ``deepseek`` -> ``langchain-deepseek`` - - ``ibm`` -> ``langchain-ibm`` - - ``nvidia`` -> ``langchain-nvidia-ai-endpoints`` - - ``xai`` -> ``langchain-xai`` - - ``perplexity`` -> ``langchain-perplexity`` + - `openai` -> `langchain-openai` + - `anthropic` -> `langchain-anthropic` + - `azure_openai` -> `langchain-openai` + - `azure_ai` -> `langchain-azure-ai` + - `google_vertexai` -> `langchain-google-vertexai` + - `google_genai` -> `langchain-google-genai` + - `bedrock` -> `langchain-aws` + - `bedrock_converse` -> `langchain-aws` + - `cohere` -> `langchain-cohere` + - `fireworks` -> `langchain-fireworks` + - `together` -> `langchain-together` + - `mistralai` -> `langchain-mistralai` + - `huggingface` -> `langchain-huggingface` + - `groq` -> `langchain-groq` + - `ollama` -> `langchain-ollama` + - `google_anthropic_vertex` -> `langchain-google-vertexai` + - `deepseek` -> `langchain-deepseek` + - `ibm` -> `langchain-ibm` + - `nvidia` -> `langchain-nvidia-ai-endpoints` + - `xai` -> `langchain-xai` + - `perplexity` -> `langchain-perplexity` Will attempt to infer model_provider from model if not specified. The following providers will be inferred based on these model prefixes: - - ``gpt-...`` | ``o1...`` | ``o3...`` -> ``openai`` - - ``claude...`` -> ``anthropic`` - - ``amazon...`` -> ``bedrock`` - - ``gemini...`` -> ``google_vertexai`` - - ``command...`` -> ``cohere`` - - ``accounts/fireworks...`` -> ``fireworks`` - - ``mistral...`` -> ``mistralai`` - - ``deepseek...`` -> ``deepseek`` - - ``grok...`` -> ``xai`` - - ``sonar...`` -> ``perplexity`` + - `gpt-...` | `o1...` | `o3...` -> `openai` + - `claude...` -> `anthropic` + - `amazon...` -> `bedrock` + - `gemini...` -> `google_vertexai` + - `command...` -> `cohere` + - `accounts/fireworks...` -> `fireworks` + - `mistral...` -> `mistralai` + - `deepseek...` -> `deepseek` + - `grok...` -> `xai` + - `sonar...` -> `perplexity` configurable_fields: Which model parameters are configurable: - None: No configurable fields. diff --git a/libs/partners/anthropic/langchain_anthropic/_client_utils.py b/libs/partners/anthropic/langchain_anthropic/_client_utils.py index 53fd3c801b1..55f194f4dbe 100644 --- a/libs/partners/anthropic/langchain_anthropic/_client_utils.py +++ b/libs/partners/anthropic/langchain_anthropic/_client_utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import asyncio import os from functools import lru_cache -from typing import Any, Optional +from typing import Any import anthropic @@ -48,9 +48,9 @@ class _AsyncHttpxClientWrapper(anthropic.DefaultAsyncHttpxClient): @lru_cache def _get_default_httpx_client( *, - base_url: Optional[str], + base_url: str | None, timeout: Any = _NOT_GIVEN, - anthropic_proxy: Optional[str] = None, + anthropic_proxy: str | None = None, ) -> _SyncHttpxClientWrapper: kwargs: dict[str, Any] = { "base_url": base_url @@ -67,9 +67,9 @@ def _get_default_httpx_client( @lru_cache def _get_default_async_httpx_client( *, - base_url: Optional[str], + base_url: str | None, timeout: Any = _NOT_GIVEN, - anthropic_proxy: Optional[str] = None, + anthropic_proxy: str | None = None, ) -> _AsyncHttpxClientWrapper: kwargs: dict[str, Any] = { "base_url": base_url diff --git a/libs/partners/anthropic/langchain_anthropic/_compat.py b/libs/partners/anthropic/langchain_anthropic/_compat.py index 8093598df97..c86ceef3dc0 100644 --- a/libs/partners/anthropic/langchain_anthropic/_compat.py +++ b/libs/partners/anthropic/langchain_anthropic/_compat.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any, Optional, cast +from typing import Any, cast from langchain_core.messages import content as types @@ -97,7 +97,7 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]: def _convert_from_v1_to_anthropic( content: list[types.ContentBlock], tool_calls: list[types.ToolCall], - model_provider: Optional[str], + model_provider: str | None, ) -> list[dict[str, Any]]: new_content: list = [] for block in content: diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 1f99f97d6a7..99c91a0075f 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -9,7 +9,7 @@ import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from functools import cached_property from operator import itemgetter -from typing import Any, Final, Literal, Optional, Union, cast +from typing import Any, Final, Literal, cast import anthropic from langchain_core.callbacks import ( @@ -170,7 +170,7 @@ def _format_image(url: str) -> dict: def _merge_messages( messages: Sequence[BaseMessage], -) -> list[Union[SystemMessage, AIMessage, HumanMessage]]: +) -> list[SystemMessage | AIMessage | HumanMessage]: """Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501 merged: list = [] for curr in messages: @@ -344,9 +344,9 @@ def _format_data_content_block(block: dict) -> dict: def _format_messages( messages: Sequence[BaseMessage], -) -> tuple[Union[str, list[dict], None], list[dict]]: +) -> tuple[str | list[dict] | None, list[dict]]: """Format messages for Anthropic's API.""" - system: Union[str, list[dict], None] = None + system: str | list[dict] | None = None formatted_messages: list[dict] = [] merged_messages = _merge_messages(messages) for _i, message in enumerate(merged_messages): @@ -368,7 +368,7 @@ def _format_messages( continue role = _message_type_lookups[message.type] - content: Union[str, list] + content: str | list if not isinstance(message.content, str): # parse as dict @@ -1419,29 +1419,29 @@ class ChatAnthropic(BaseChatModel): model: str = Field(alias="model_name") """Model name to use.""" - max_tokens: Optional[int] = Field(default=None, alias="max_tokens_to_sample") + max_tokens: int | None = Field(default=None, alias="max_tokens_to_sample") """Denotes the number of tokens to predict per generation.""" - temperature: Optional[float] = None + temperature: float | None = None """A non-negative float that tunes the degree of randomness in generation.""" - top_k: Optional[int] = None + top_k: int | None = None """Number of most likely tokens to consider at each step.""" - top_p: Optional[float] = None + top_p: float | None = None """Total probability mass of tokens to consider at each step.""" - default_request_timeout: Optional[float] = Field(None, alias="timeout") + default_request_timeout: float | None = Field(None, alias="timeout") """Timeout for requests to Anthropic Completion API.""" # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries max_retries: int = 2 """Number of retries allowed for requests sent to the Anthropic Completion API.""" - stop_sequences: Optional[list[str]] = Field(None, alias="stop") + stop_sequences: list[str] | None = Field(None, alias="stop") """Default stop sequences.""" - anthropic_api_url: Optional[str] = Field( + anthropic_api_url: str | None = Field( alias="base_url", default_factory=from_env( ["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"], @@ -1462,7 +1462,7 @@ class ChatAnthropic(BaseChatModel): ) """Automatically read from env var ``ANTHROPIC_API_KEY`` if not provided.""" - anthropic_proxy: Optional[str] = Field( + anthropic_proxy: str | None = Field( default_factory=from_env("ANTHROPIC_PROXY", default=None) ) """Proxy to use for the Anthropic clients, will be used for every API call. @@ -1470,10 +1470,10 @@ class ChatAnthropic(BaseChatModel): If not provided, will attempt to read from the ``ANTHROPIC_PROXY`` environment variable.""" - default_headers: Optional[Mapping[str, str]] = None + default_headers: Mapping[str, str] | None = None """Headers to pass to the Anthropic clients, will be used for every API call.""" - betas: Optional[list[str]] = None + betas: list[str] | None = None """List of beta features to enable. If specified, invocations will be routed through client.beta.messages.create. @@ -1490,18 +1490,18 @@ class ChatAnthropic(BaseChatModel): message chunks will be generated during the stream including usage metadata. """ - thinking: Optional[dict[str, Any]] = Field(default=None) + thinking: dict[str, Any] | None = Field(default=None) """Parameters for Claude reasoning, e.g., ``{"type": "enabled", "budget_tokens": 10_000}``""" - mcp_servers: Optional[list[dict[str, Any]]] = None + mcp_servers: list[dict[str, Any]] | None = None """List of MCP servers to use for the request. Example: ``mcp_servers=[{"type": "url", "url": "https://mcp.example.com/mcp", "name": "example-mcp"}]`` """ - context_management: Optional[dict[str, Any]] = None + context_management: dict[str, Any] | None = None """Configuration for `context management `__. """ @@ -1547,7 +1547,7 @@ class ChatAnthropic(BaseChatModel): def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -1630,7 +1630,7 @@ class ChatAnthropic(BaseChatModel): self, input_: LanguageModelInput, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: dict, ) -> dict: """Get the request payload for the Anthropic API.""" @@ -1717,10 +1717,10 @@ class ChatAnthropic(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, *, - stream_usage: Optional[bool] = None, + stream_usage: bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: if stream_usage is None: @@ -1753,10 +1753,10 @@ class ChatAnthropic(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, *, - stream_usage: Optional[bool] = None, + stream_usage: bool | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: if stream_usage is None: @@ -1839,8 +1839,8 @@ class ChatAnthropic(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1853,8 +1853,8 @@ class ChatAnthropic(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1866,7 +1866,7 @@ class ChatAnthropic(BaseChatModel): def _get_llm_for_structured_output_when_thinking_is_enabled( self, - schema: Union[dict, type], + schema: dict | type, formatted_tool: AnthropicTool, ) -> Runnable[LanguageModelInput, BaseMessage]: thinking_admonition = ( @@ -1894,12 +1894,10 @@ class ChatAnthropic(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type | Callable | BaseTool], *, - tool_choice: Optional[ - Union[dict[str, str], Literal["any", "auto"], str] # noqa: PYI051 - ] = None, - parallel_tool_calls: Optional[bool] = None, + tool_choice: dict[str, str] | str | None = None, + parallel_tool_calls: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: r"""Bind tool-like objects to this chat model. @@ -1910,15 +1908,14 @@ class ChatAnthropic(BaseChatModel): by `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. Options are: - - name of the tool as a string or as dict ``{"type": "tool", "name": "<>"}``: calls corresponding tool; - - ``'auto'``, ``{"type: "auto"}``, or `None`: automatically selects a tool (including no tool); - - `'any'` or ``{"type: "any"}``: force at least one tool to be called; + - name of the tool as a string or as dict `{"type": "tool", "name": "<>"}`: calls corresponding tool; + - `'auto'`, `{"type: "auto"}`, or `None`: automatically selects a tool (including no tool); + - `'any'` or `{"type: "any"}`: force at least one tool to be called; parallel_tool_calls: Set to `False` to disable parallel tool use. Defaults to `None` (no specification, which allows parallel tool use). !!! version-added "Added in version 0.3.2" - kwargs: Any additional parameters are passed directly to - `langchain_anthropic.chat_models.ChatAnthropic.bind`. + kwargs: Any additional parameters are passed directly to `bind`. Example: @@ -2170,11 +2167,11 @@ class ChatAnthropic(BaseChatModel): def with_structured_output( self, - schema: Union[dict, type], + schema: dict | type, *, include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, dict | BaseModel]: """Model wrapper that returns outputs formatted to match the given schema. Args: @@ -2340,9 +2337,7 @@ class ChatAnthropic(BaseChatModel): def get_num_tokens_from_messages( self, messages: list[BaseMessage], - tools: Optional[ - Sequence[Union[dict[str, Any], type, Callable, BaseTool]] - ] = None, + tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None, **kwargs: Any, ) -> int: """Count tokens in a sequence of input messages. @@ -2430,7 +2425,7 @@ class ChatAnthropic(BaseChatModel): def convert_to_anthropic_tool( - tool: Union[dict[str, Any], type, Callable, BaseTool], + tool: dict[str, Any] | type | Callable | BaseTool, ) -> AnthropicTool: """Convert a tool-like object to an Anthropic tool definition.""" # already in Anthropic tool format @@ -2500,8 +2495,8 @@ def _make_message_chunk_from_anthropic_event( *, stream_usage: bool = True, coerce_content_to_string: bool, - block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None, -) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]: + block_start_event: anthropic.types.RawMessageStreamEvent | None = None, +) -> tuple[AIMessageChunk | None, anthropic.types.RawMessageStreamEvent | None]: """Convert Anthropic streaming event to `AIMessageChunk`. Args: @@ -2526,7 +2521,7 @@ def _make_message_chunk_from_anthropic_event( updating the `block_start_event` for context tracking. """ - message_chunk: Optional[AIMessageChunk] = None + message_chunk: AIMessageChunk | None = None # Reference: Anthropic SDK streaming implementation # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501 diff --git a/libs/partners/anthropic/langchain_anthropic/experimental.py b/libs/partners/anthropic/langchain_anthropic/experimental.py index 041b49befec..d0f9d8d039e 100644 --- a/libs/partners/anthropic/langchain_anthropic/experimental.py +++ b/libs/partners/anthropic/langchain_anthropic/experimental.py @@ -5,7 +5,6 @@ from __future__ import annotations import json from typing import ( Any, - Union, ) SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question. @@ -83,7 +82,7 @@ def get_system_message(tools: list[dict]) -> str: return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted) -def _xml_to_dict(t: Any) -> Union[str, dict[str, Any]]: +def _xml_to_dict(t: Any) -> str | dict[str, Any]: # Base case: If the element has no children, return its text or an empty string. if len(t) == 0: return t.text or "" diff --git a/libs/partners/anthropic/langchain_anthropic/llms.py b/libs/partners/anthropic/langchain_anthropic/llms.py index 5132b050a82..6c7f19e8e13 100644 --- a/libs/partners/anthropic/langchain_anthropic/llms.py +++ b/libs/partners/anthropic/langchain_anthropic/llms.py @@ -5,7 +5,7 @@ from __future__ import annotations import re import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping -from typing import Any, Optional +from typing import Any import anthropic from langchain_core.callbacks import ( @@ -31,25 +31,25 @@ class _AnthropicCommon(BaseLanguageModel): max_tokens: int = Field(default=1024, alias="max_tokens_to_sample") """Denotes the number of tokens to predict per generation.""" - temperature: Optional[float] = None + temperature: float | None = None """A non-negative float that tunes the degree of randomness in generation.""" - top_k: Optional[int] = None + top_k: int | None = None """Number of most likely tokens to consider at each step.""" - top_p: Optional[float] = None + top_p: float | None = None """Total probability mass of tokens to consider at each step.""" streaming: bool = False """Whether to stream the results.""" - default_request_timeout: Optional[float] = None + default_request_timeout: float | None = None """Timeout for requests to Anthropic Completion API. Default is 600 seconds.""" max_retries: int = 2 """Number of retries allowed for requests sent to the Anthropic Completion API.""" - anthropic_api_url: Optional[str] = Field( + anthropic_api_url: str | None = Field( alias="base_url", default_factory=from_env( "ANTHROPIC_API_URL", @@ -69,9 +69,9 @@ class _AnthropicCommon(BaseLanguageModel): ) """Automatically read from env var ``ANTHROPIC_API_KEY`` if not provided.""" - HUMAN_PROMPT: Optional[str] = None - AI_PROMPT: Optional[str] = None - count_tokens: Optional[Callable[[str], int]] = None + HUMAN_PROMPT: str | None = None + AI_PROMPT: str | None = None + count_tokens: Callable[[str], int] | None = None model_kwargs: dict[str, Any] = Field(default_factory=dict) @model_validator(mode="before") @@ -120,7 +120,7 @@ class _AnthropicCommon(BaseLanguageModel): """Get the identifying parameters.""" return {**self._default_params} - def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]: + def _get_anthropic_stop(self, stop: list[str] | None = None) -> list[str]: if stop is None: stop = [] return stop @@ -190,7 +190,7 @@ class AnthropicLLM(LLM, _AnthropicCommon): def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -245,8 +245,8 @@ class AnthropicLLM(LLM, _AnthropicCommon): def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: r"""Call out to Anthropic's completion endpoint. @@ -299,8 +299,8 @@ class AnthropicLLM(LLM, _AnthropicCommon): async def _acall( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Call out to Anthropic's completion endpoint asynchronously.""" @@ -331,8 +331,8 @@ class AnthropicLLM(LLM, _AnthropicCommon): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: r"""Call Anthropic completion_stream and return the resulting generator. @@ -378,8 +378,8 @@ class AnthropicLLM(LLM, _AnthropicCommon): async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: r"""Call Anthropic completion_stream and return the resulting generator. diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index 5cb6f362cdb..20e60fb64cb 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional, Union, cast +from typing import Any, cast from langchain_core.messages import AIMessage, ToolCall from langchain_core.messages.tool import tool_call @@ -18,7 +18,7 @@ class ToolsOutputParser(BaseGenerationOutputParser): """Whether to return only the first tool call.""" args_only: bool = False """Whether to return only the arguments of the tool calls.""" - pydantic_schemas: Optional[list[type[BaseModel]]] = None + pydantic_schemas: list[type[BaseModel]] | None = None """Pydantic schemas to parse tool calls into.""" model_config = ConfigDict( @@ -77,7 +77,7 @@ def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]: return extract_tool_calls(message.content) -def extract_tool_calls(content: Union[str, list[Union[str, dict]]]) -> list[ToolCall]: +def extract_tool_calls(content: str | list[str | dict]) -> list[ToolCall]: """Extract tool calls from a list of content blocks.""" if isinstance(content, list): tool_calls = [] diff --git a/libs/partners/anthropic/pyproject.toml b/libs/partners/anthropic/pyproject.toml index d080a51e38e..7eb22a60be2 100644 --- a/libs/partners/anthropic/pyproject.toml +++ b/libs/partners/anthropic/pyproject.toml @@ -69,8 +69,6 @@ ignore = [ "COM812", # Messes with the formatter "ISC001", # Messes with the formatter "PERF203", # Rarely useful - "UP007", # non-pep604-annotation-union - "UP045", # non-pep604-annotation-optional "SIM105", # Rarely useful "FIX", # TODOs "TD", # TODOs diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 218b098cfa4..a8eb60cb405 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -6,7 +6,7 @@ import asyncio import json import os from base64 import b64encode -from typing import Literal, Optional, cast +from typing import Literal, cast import httpx import pytest @@ -40,7 +40,7 @@ def test_stream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropic(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_input_token_counts = 0 chunks_with_output_token_counts = 0 chunks_with_model_name = 0 @@ -86,7 +86,7 @@ async def test_astream() -> None: """Test streaming tokens from Anthropic.""" llm = ChatAnthropic(model_name=MODEL_NAME) # type: ignore[call-arg, call-arg] - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_input_token_counts = 0 chunks_with_output_token_counts = 0 async for token in llm.astream("I'm Pickle Rick"): @@ -810,7 +810,7 @@ def test_citations(output_version: Literal["v0", "v1"]) -> None: assert any("citations" in block for block in response.content) # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream(messages): full = cast("BaseMessageChunk", chunk) if full is None else full + chunk assert isinstance(full, AIMessageChunk) @@ -850,7 +850,7 @@ def test_thinking() -> None: assert isinstance(block["signature"], str) # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message]): full = cast("BaseMessageChunk", chunk) if full is None else full + chunk assert isinstance(full, AIMessageChunk) @@ -894,7 +894,7 @@ def test_thinking_v1() -> None: assert isinstance(signature, str) # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message]): full = cast(BaseMessageChunk, chunk) if full is None else full + chunk assert isinstance(full, AIMessageChunk) @@ -948,7 +948,7 @@ def test_redacted_thinking(output_version: Literal["v0", "v1"]) -> None: assert value is not None # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message]): full = cast("BaseMessageChunk", chunk) if full is None else full + chunk assert isinstance(full, AIMessageChunk) @@ -1108,7 +1108,7 @@ def test_web_search(output_version: Literal["v0", "v1"]) -> None: assert block_types == {"text", "server_tool_call", "server_tool_result"} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1173,7 +1173,7 @@ def test_web_fetch_v1(output_version: Literal["v0", "v1"]) -> None: assert block_types == {"text", call_key, result_key} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1394,7 +1394,7 @@ def test_web_fetch() -> None: ) # Streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1493,7 +1493,7 @@ def test_code_execution(output_version: Literal["v0", "v1"]) -> None: assert block_types == {"text", "server_tool_call", "server_tool_result"} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1562,7 +1562,7 @@ def test_remote_mcp(output_version: Literal["v0", "v1"]) -> None: assert block_types == {"text", "server_tool_call", "server_tool_result"} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1798,7 +1798,7 @@ def test_context_management() -> None: assert response.response_metadata.get("context_management") # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk diff --git a/libs/partners/anthropic/tests/unit_tests/_utils.py b/libs/partners/anthropic/tests/unit_tests/_utils.py index 5f0db1f9f41..683337eb45a 100644 --- a/libs/partners/anthropic/tests/unit_tests/_utils.py +++ b/libs/partners/anthropic/tests/unit_tests/_utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Union +from typing import Any from langchain_core.callbacks import BaseCallbackHandler from pydantic import BaseModel @@ -22,7 +22,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_chat_model_: bool = False # to allow for similar callback handlers that are not technically equal - fake_id: Union[str, None] = None + fake_id: str | None = None # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index aaf53df41aa..df55402ed56 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -4,7 +4,7 @@ from __future__ import annotations import os from collections.abc import Callable -from typing import Any, Literal, Optional, cast +from typing import Any, Literal, cast from unittest.mock import MagicMock, patch import anthropic @@ -1324,10 +1324,10 @@ def test_usage_metadata_standardization() -> None: # Null input and output tokens class UsageModelNulls(BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - cache_read_input_tokens: Optional[int] = None - cache_creation_input_tokens: Optional[int] = None + input_tokens: int | None = None + output_tokens: int | None = None + cache_read_input_tokens: int | None = None + cache_creation_input_tokens: int | None = None usage_nulls = UsageModelNulls() result = _create_usage_metadata(usage_nulls) diff --git a/libs/partners/anthropic/uv.lock b/libs/partners/anthropic/uv.lock index bb1ace99fb4..c8fee5e9780 100644 --- a/libs/partners/anthropic/uv.lock +++ b/libs/partners/anthropic/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -496,7 +496,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.0.0a7" +version = "1.0.0a8" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/partners/chroma/langchain_chroma/__init__.py b/libs/partners/chroma/langchain_chroma/__init__.py index 27d97164bb9..f8e851c48b0 100644 --- a/libs/partners/chroma/langchain_chroma/__init__.py +++ b/libs/partners/chroma/langchain_chroma/__init__.py @@ -1,3 +1,5 @@ +"""LangChain integration for Chroma vector database.""" + from langchain_chroma.vectorstores import Chroma __all__ = [ diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py index 597ed5a656f..036d69aab50 100644 --- a/libs/partners/chroma/langchain_chroma/vectorstores.py +++ b/libs/partners/chroma/langchain_chroma/vectorstores.py @@ -9,11 +9,10 @@ import base64 import logging import uuid from collections.abc import Callable, Iterable, Sequence +from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) import chromadb @@ -67,10 +66,10 @@ def _results_to_docs_and_vectors(results: Any) -> list[tuple[Document, np.ndarra ] -Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] +Matrix = list[list[float]], list[np.ndarray] | np.ndarray -def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: +def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: # type: ignore[valid-type] """Row-wise cosine similarity between two equal-width matrices. Raises: @@ -301,20 +300,20 @@ class Chroma(VectorStore): def __init__( self, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - embedding_function: Optional[Embeddings] = None, - persist_directory: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - headers: Optional[dict[str, str]] = None, - chroma_cloud_api_key: Optional[str] = None, - tenant: Optional[str] = None, - database: Optional[str] = None, - client_settings: Optional[chromadb.config.Settings] = None, - collection_metadata: Optional[dict] = None, - collection_configuration: Optional[CreateCollectionConfiguration] = None, - client: Optional[chromadb.ClientAPI] = None, - relevance_score_fn: Optional[Callable[[float], float]] = None, - create_collection_if_not_exists: Optional[bool] = True, # noqa: FBT001, FBT002 + embedding_function: Embeddings | None = None, + persist_directory: str | None = None, + host: str | None = None, + port: int | None = None, + headers: dict[str, str] | None = None, + chroma_cloud_api_key: str | None = None, + tenant: str | None = None, + database: str | None = None, + client_settings: chromadb.config.Settings | None = None, + collection_metadata: dict | None = None, + collection_configuration: CreateCollectionConfiguration | None = None, + client: chromadb.ClientAPI | None = None, + relevance_score_fn: Callable[[float], float] | None = None, + create_collection_if_not_exists: bool | None = True, # noqa: FBT001, FBT002 *, ssl: bool = False, ) -> None: @@ -408,7 +407,7 @@ class Chroma(VectorStore): self._client = chromadb.Client(settings=_settings) self._embedding_function = embedding_function - self._chroma_collection: Optional[chromadb.Collection] = None + self._chroma_collection: chromadb.Collection | None = None self._collection_name = collection_name self._collection_metadata = collection_metadata self._collection_configuration = collection_configuration @@ -441,20 +440,20 @@ class Chroma(VectorStore): return self._chroma_collection @property - def embeddings(self) -> Optional[Embeddings]: + def embeddings(self) -> Embeddings | None: """Access the query embedding object.""" return self._embedding_function @xor_args(("query_texts", "query_embeddings")) def __query_collection( self, - query_texts: Optional[list[str]] = None, - query_embeddings: Optional[list[list[float]]] = None, + query_texts: list[str] | None = None, + query_embeddings: list[list[float]] | None = None, n_results: int = 4, - where: Optional[dict[str, str]] = None, - where_document: Optional[dict[str, str]] = None, + where: dict[str, str] | None = None, + where_document: dict[str, str] | None = None, **kwargs: Any, - ) -> Union[list[Document], chromadb.QueryResult]: + ) -> list[Document] | chromadb.QueryResult: """Query the chroma collection. Args: @@ -485,7 +484,7 @@ class Chroma(VectorStore): @staticmethod def encode_image(uri: str) -> str: """Get base64 string from image URI.""" - with open(uri, "rb") as image_file: + with Path(uri).open("rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") def fork(self, new_name: str) -> Chroma: @@ -508,8 +507,8 @@ class Chroma(VectorStore): def add_images( self, uris: list[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, ) -> list[str]: """Run more images through the embeddings and add to the vectorstore. @@ -596,8 +595,8 @@ class Chroma(VectorStore): def add_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: """Run more texts through the embeddings and add to the vectorstore. @@ -684,7 +683,7 @@ class Chroma(VectorStore): self, query: str, k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 + filter: dict[str, str] | None = None, # noqa: A002 **kwargs: Any, ) -> list[Document]: """Run similarity search with Chroma. @@ -710,8 +709,8 @@ class Chroma(VectorStore): self, embedding: list[float], k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. @@ -740,8 +739,8 @@ class Chroma(VectorStore): self, embedding: list[float], k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to embedding vector and similarity score. @@ -771,8 +770,8 @@ class Chroma(VectorStore): self, query: str, k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Run similarity search with Chroma with distance. @@ -813,8 +812,8 @@ class Chroma(VectorStore): self, query: str, k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[tuple[Document, np.ndarray]]: """Run similarity search with Chroma with vectors. @@ -873,12 +872,10 @@ class Chroma(VectorStore): return self.override_relevance_score_fn hnsw_config = self._collection.configuration.get("hnsw") - hnsw_distance: Optional[str] = hnsw_config.get("space") if hnsw_config else None + hnsw_distance: str | None = hnsw_config.get("space") if hnsw_config else None spann_config = self._collection.configuration.get("spann") - spann_distance: Optional[str] = ( - spann_config.get("space") if spann_config else None - ) + spann_distance: str | None = spann_config.get("space") if spann_config else None distance = hnsw_distance or spann_distance @@ -901,7 +898,7 @@ class Chroma(VectorStore): self, uri: str, k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 + filter: dict[str, str] | None = None, # noqa: A002 **kwargs: Any, ) -> list[Document]: """Search for similar images based on the given image URI. @@ -943,7 +940,7 @@ class Chroma(VectorStore): self, uri: str, k: int = DEFAULT_K, - filter: Optional[dict[str, str]] = None, # noqa: A002 + filter: dict[str, str] | None = None, # noqa: A002 **kwargs: Any, ) -> list[tuple[Document, float]]: """Search for similar images based on the given image URI. @@ -987,8 +984,8 @@ class Chroma(VectorStore): k: int = DEFAULT_K, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -1038,8 +1035,8 @@ class Chroma(VectorStore): k: int = DEFAULT_K, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[dict[str, str]] = None, # noqa: A002 - where_document: Optional[dict[str, str]] = None, + filter: dict[str, str] | None = None, # noqa: A002 + where_document: dict[str, str] | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -1097,12 +1094,12 @@ class Chroma(VectorStore): def get( self, - ids: Optional[Union[str, list[str]]] = None, - where: Optional[Where] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: Optional[WhereDocument] = None, - include: Optional[list[str]] = None, + ids: str | list[str] | None = None, + where: Where | None = None, + limit: int | None = None, + offset: int | None = None, + where_document: WhereDocument | None = None, + include: list[str] | None = None, ) -> dict[str, Any]: """Gets the collection. @@ -1233,21 +1230,21 @@ class Chroma(VectorStore): def from_texts( cls: type[Chroma], texts: list[str], - embedding: Optional[Embeddings] = None, - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, + embedding: Embeddings | None = None, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - persist_directory: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - headers: Optional[dict[str, str]] = None, - chroma_cloud_api_key: Optional[str] = None, - tenant: Optional[str] = None, - database: Optional[str] = None, - client_settings: Optional[chromadb.config.Settings] = None, - client: Optional[chromadb.ClientAPI] = None, - collection_metadata: Optional[dict] = None, - collection_configuration: Optional[CreateCollectionConfiguration] = None, + persist_directory: str | None = None, + host: str | None = None, + port: int | None = None, + headers: dict[str, str] | None = None, + chroma_cloud_api_key: str | None = None, + tenant: str | None = None, + database: str | None = None, + client_settings: chromadb.config.Settings | None = None, + client: chromadb.ClientAPI | None = None, + collection_metadata: dict | None = None, + collection_configuration: CreateCollectionConfiguration | None = None, *, ssl: bool = False, **kwargs: Any, @@ -1335,20 +1332,20 @@ class Chroma(VectorStore): def from_documents( cls: type[Chroma], documents: list[Document], - embedding: Optional[Embeddings] = None, - ids: Optional[list[str]] = None, + embedding: Embeddings | None = None, + ids: list[str] | None = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - persist_directory: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - headers: Optional[dict[str, str]] = None, - chroma_cloud_api_key: Optional[str] = None, - tenant: Optional[str] = None, - database: Optional[str] = None, - client_settings: Optional[chromadb.config.Settings] = None, - client: Optional[chromadb.ClientAPI] = None, # Add this line - collection_metadata: Optional[dict] = None, - collection_configuration: Optional[CreateCollectionConfiguration] = None, + persist_directory: str | None = None, + host: str | None = None, + port: int | None = None, + headers: dict[str, str] | None = None, + chroma_cloud_api_key: str | None = None, + tenant: str | None = None, + database: str | None = None, + client_settings: chromadb.config.Settings | None = None, + client: chromadb.ClientAPI | None = None, # Add this line + collection_metadata: dict | None = None, + collection_configuration: CreateCollectionConfiguration | None = None, *, ssl: bool = False, **kwargs: Any, @@ -1410,7 +1407,7 @@ class Chroma(VectorStore): **kwargs, ) - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None: + def delete(self, ids: list[str] | None = None, **kwargs: Any) -> None: """Delete by vector IDs. Args: diff --git a/libs/partners/chroma/pyproject.toml b/libs/partners/chroma/pyproject.toml index c6ea6dd15a1..4fa46860ef5 100644 --- a/libs/partners/chroma/pyproject.toml +++ b/libs/partners/chroma/pyproject.toml @@ -66,53 +66,27 @@ disallow_untyped_defs = true docstring-code-format = true [tool.ruff.lint] -select = [ - "A", # flake8-builtins - "B", # flake8-bugbear - "ASYNC", # flake8-async - "C4", # flake8-comprehensions - "COM", # flake8-commas - "D", # pydocstyle - "E", # pycodestyle error - "EM", # flake8-errmsg - "F", # pyflakes - "FA", # flake8-future-annotations - "FBT", # flake8-boolean-trap - "FLY", # flake8-flynt - "I", # isort - "ICN", # flake8-import-conventions - "INT", # flake8-gettext - "ISC", # isort-comprehensions - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PERF", # flake8-perf - "PYI", # flake8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-rst-docstrings - "RUF", # ruff - "S", # flake8-bandit - "SLF", # flake8-self - "SLOT", # flake8-slots - "SIM", # flake8-simplify - "T10", # flake8-debugger - "T20", # flake8-print - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 -] +select = [ "ALL" ] ignore = [ - "D100", # pydocstyle: Missing docstring in public module - "D101", # pydocstyle: Missing docstring in public class - "D102", # pydocstyle: Missing docstring in public method - "D103", # pydocstyle: Missing docstring in public function - "D104", # pydocstyle: Missing docstring in public package - "D105", # pydocstyle: Missing docstring in magic method - "D107", # pydocstyle: Missing docstring in __init__ - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "COM812", # Messes with the formatter + "PLC0415", # Import top of file + "FIX002", # TODO + "TD002", # TODO + "TD003", # TODO + "PLR0912", # Too many branches + "PLR0913", # Too many arguments + "C901", # Too complex + + # TODO + "ANN204", + "ANN401", + "TRY201", + "ARG002", + "N803", + "TC002", + "TC003", + "TRY300", + "N806", ] unfixable = ["B028"] # People should intentionally tune the stacklevel @@ -136,7 +110,13 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa [tool.ruff.lint.extend-per-file-ignores] "tests/**/*.py" = [ - "S101", # Tests need assertions - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes - "SLF001", # Private member access in tests + "S101", # Tests need assertions + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "SLF001", # Private member access in tests + "PLR2004", # Comparison to magic number + "PT011", # Exception too broad + "BLE001", # Blind except +] +"scripts/*.py" = [ + "INP001", # Not a package ] diff --git a/libs/partners/chroma/scripts/check_imports.py b/libs/partners/chroma/scripts/check_imports.py index b850913e53e..8a1dfda747f 100644 --- a/libs/partners/chroma/scripts/check_imports.py +++ b/libs/partners/chroma/scripts/check_imports.py @@ -10,7 +10,7 @@ if __name__ == "__main__": for file in files: try: SourceFileLoader("x", file).load_module() - except Exception: # noqa: PERF203 + except Exception: # noqa: PERF203, BLE001 has_failure = True print(file) # noqa: T201 traceback.print_exc() diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py index 99943c2f5a3..3c4b0ae75ff 100644 --- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py @@ -1,9 +1,8 @@ """Test Chroma functionality.""" -import os.path import tempfile import uuid -from collections.abc import Generator +from pathlib import Path from typing import ( cast, ) @@ -29,15 +28,14 @@ class MyEmbeddingFunction: self.fak = fak def __call__(self, input_: Embeddable) -> list[list[float]]: - texts = cast(list[str], input_) + texts = cast("list[str]", input_) return self.fak.embed_documents(texts=texts) -@pytest.fixture() -def client() -> Generator[chromadb.ClientAPI, None, None]: +@pytest.fixture +def client() -> chromadb.ClientAPI: SharedSystemClient.clear_system_cache() - client = chromadb.Client(chromadb.config.Settings()) - yield client + return chromadb.Client(chromadb.config.Settings()) def test_chroma() -> None: @@ -297,7 +295,7 @@ def test_chroma_with_persistence() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo", id="id_0")] - assert os.path.exists(chroma_persist_dir) + assert Path(chroma_persist_dir).exists() # Get a new VectorStore from the persisted directory docsearch = Chroma( @@ -344,7 +342,7 @@ def test_chroma_with_persistence_with_client_settings() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo", id="id_0")] - assert os.path.exists(chroma_persist_dir) + assert Path(chroma_persist_dir).exists() # Get a new VectorStore from the persisted directory docsearch = Chroma( diff --git a/libs/partners/chroma/tests/unit_tests/test_standard.py b/libs/partners/chroma/tests/unit_tests/test_standard.py index f580e5725a4..bfb14ca5be3 100644 --- a/libs/partners/chroma/tests/unit_tests/test_standard.py +++ b/libs/partners/chroma/tests/unit_tests/test_standard.py @@ -9,7 +9,7 @@ from langchain_chroma import Chroma class TestChromaStandard(VectorStoreIntegrationTests): - @pytest.fixture() + @pytest.fixture def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore[override] """Get an empty vectorstore for unit tests.""" store = Chroma(embedding_function=self.get_embeddings()) diff --git a/libs/partners/chroma/uv.lock b/libs/partners/chroma/uv.lock index 756aeb32acd..1d128ef1480 100644 --- a/libs/partners/chroma/uv.lock +++ b/libs/partners/chroma/uv.lock @@ -839,7 +839,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.0.0a6" +version = "1.0.0a8" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/partners/deepseek/langchain_deepseek/__init__.py b/libs/partners/deepseek/langchain_deepseek/__init__.py index 519ba6281c8..07fec728057 100644 --- a/libs/partners/deepseek/langchain_deepseek/__init__.py +++ b/libs/partners/deepseek/langchain_deepseek/__init__.py @@ -1,3 +1,5 @@ +"""LangChain DeepSeek integration.""" + from importlib import metadata from langchain_deepseek.chat_models import ChatDeepSeek diff --git a/libs/partners/deepseek/langchain_deepseek/chat_models.py b/libs/partners/deepseek/langchain_deepseek/chat_models.py index 1df7a5a8410..daa42659267 100644 --- a/libs/partners/deepseek/langchain_deepseek/chat_models.py +++ b/libs/partners/deepseek/langchain_deepseek/chat_models.py @@ -5,7 +5,7 @@ from __future__ import annotations import json from collections.abc import Iterator from json import JSONDecodeError -from typing import Any, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeAlias import openai from langchain_core.callbacks import ( @@ -22,16 +22,15 @@ from typing_extensions import Self DEFAULT_API_BASE = "https://api.deepseek.com/v1" -_BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type] -_DictOrPydantic = Union[dict, _BM] +_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel] +_DictOrPydantic: TypeAlias = dict[str, Any] | BaseModel class ChatDeepSeek(BaseChatOpenAI): """DeepSeek chat model integration to access models hosted in DeepSeek's API. Setup: - Install ``langchain-deepseek`` and set environment variable ``DEEPSEEK_API_KEY``. + Install `langchain-deepseek` and set environment variable ``DEEPSEEK_API_KEY``. .. code-block:: bash @@ -170,7 +169,7 @@ class ChatDeepSeek(BaseChatOpenAI): model_name: str = Field(alias="model") """The name of the model""" - api_key: Optional[SecretStr] = Field( + api_key: SecretStr | None = Field( default_factory=secret_from_env("DEEPSEEK_API_KEY", default=None), ) """DeepSeek API key""" @@ -193,7 +192,7 @@ class ChatDeepSeek(BaseChatOpenAI): def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: ls_params = super()._get_ls_params(stop=stop, **kwargs) @@ -202,6 +201,7 @@ class ChatDeepSeek(BaseChatOpenAI): @model_validator(mode="after") def validate_environment(self) -> Self: + """Validate necessary environment vars and client params.""" if self.api_base == DEFAULT_API_BASE and not ( self.api_key and self.api_key.get_secret_value() ): @@ -237,19 +237,30 @@ class ChatDeepSeek(BaseChatOpenAI): self, input_: LanguageModelInput, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict: payload = super()._get_request_payload(input_, stop=stop, **kwargs) for message in payload["messages"]: if message["role"] == "tool" and isinstance(message["content"], list): message["content"] = json.dumps(message["content"]) + elif message["role"] == "assistant" and isinstance( + message["content"], list + ): + # DeepSeek API expects assistant content to be a string, not a list. + # Extract text blocks and join them, or use empty string if none exist. + text_parts = [ + block.get("text", "") + for block in message["content"] + if isinstance(block, dict) and block.get("type") == "text" + ] + message["content"] = "".join(text_parts) if text_parts else "" return payload def _create_chat_result( self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[dict] = None, + response: dict | openai.BaseModel, + generation_info: dict | None = None, ) -> ChatResult: rtn = super()._create_chat_result(response, generation_info) @@ -277,8 +288,8 @@ class ChatDeepSeek(BaseChatOpenAI): self, chunk: dict, default_chunk_class: type, - base_generation_info: Optional[dict], - ) -> Optional[ChatGenerationChunk]: + base_generation_info: dict | None, + ) -> ChatGenerationChunk | None: generation_chunk = super()._convert_chunk_to_generation_chunk( chunk, default_chunk_class, @@ -304,8 +315,8 @@ class ChatDeepSeek(BaseChatOpenAI): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: try: @@ -329,8 +340,8 @@ class ChatDeepSeek(BaseChatOpenAI): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: try: @@ -353,7 +364,7 @@ class ChatDeepSeek(BaseChatOpenAI): def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal[ "function_calling", @@ -361,7 +372,7 @@ class ChatDeepSeek(BaseChatOpenAI): "json_schema", ] = "function_calling", include_raw: bool = False, - strict: Optional[bool] = None, + strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema. diff --git a/libs/partners/deepseek/pyproject.toml b/libs/partners/deepseek/pyproject.toml index bc1791d1c06..759ceb4e8a3 100644 --- a/libs/partners/deepseek/pyproject.toml +++ b/libs/partners/deepseek/pyproject.toml @@ -53,52 +53,16 @@ docstring-code-format = true docstring-code-line-length = 100 [tool.ruff.lint] -select = [ - "A", # flake8-builtins - "ASYNC", # flake8-async - "C4", # flake8-comprehensions - "COM", # flake8-commas - "D", # pydocstyle - "E", # pycodestyle error - "EM", # flake8-errmsg - "F", # pyflakes - "FA", # flake8-future-annotations - "FBT", # flake8-boolean-trap - "FLY", # flake8-flynt - "I", # isort - "ICN", # flake8-import-conventions - "INT", # flake8-gettext - "ISC", # isort-comprehensions - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PERF", # flake8-perf - "PYI", # flake8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-rst-docstrings - "RUF", # ruff - "S", # flake8-bandit - "SLF", # flake8-self - "SLOT", # flake8-slots - "SIM", # flake8-simplify - "T10", # flake8-debugger - "T20", # flake8-print - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 -] +select = [ "ALL" ] ignore = [ "COM812", # Conflicts with formatter - "D100", # Missing docstring in public module - "D101", # Missing docstring in public class - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - "D107", # Missing docstring in __init__ - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional + "PLR0913", # Too many arguments + + # TODO + "ANN401", + "TC002", + "TC003", + "ANN401", ] unfixable = ["B028"] # People should intentionally tune the stacklevel @@ -121,4 +85,10 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa "S101", # Tests need assertions "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "SLF001", # Private member access + + # TODO + "ARG002", # Unused method argument: +] +"scripts/*.py" = [ + "INP001", # Not a package ] diff --git a/libs/partners/deepseek/scripts/check_imports.py b/libs/partners/deepseek/scripts/check_imports.py index 724644fce53..a703229885e 100644 --- a/libs/partners/deepseek/scripts/check_imports.py +++ b/libs/partners/deepseek/scripts/check_imports.py @@ -1,3 +1,5 @@ +"""Script to check imports of given Python files.""" + import sys import traceback from importlib.machinery import SourceFileLoader @@ -8,7 +10,7 @@ if __name__ == "__main__": for file in files: try: SourceFileLoader("x", file).load_module() - except Exception: # noqa: PERF203 + except Exception: # noqa: PERF203, BLE001 has_failure = True print(file) # noqa: T201 traceback.print_exc() diff --git a/libs/partners/deepseek/tests/__init__.py b/libs/partners/deepseek/tests/__init__.py index e69de29bb2d..b83fa24ef75 100644 --- a/libs/partners/deepseek/tests/__init__.py +++ b/libs/partners/deepseek/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for `langchain_deepseek` package.""" diff --git a/libs/partners/deepseek/tests/integration_tests/__init__.py b/libs/partners/deepseek/tests/integration_tests/__init__.py index e69de29bb2d..6edd30dfe91 100644 --- a/libs/partners/deepseek/tests/integration_tests/__init__.py +++ b/libs/partners/deepseek/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""Integration tests for `langchain_deepseek` package.""" diff --git a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py index b0f84ed1d8c..817162901fc 100644 --- a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py +++ b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessageChunk, BaseMessageChunk @@ -14,13 +12,16 @@ from langchain_deepseek.chat_models import ChatDeepSeek class TestChatDeepSeek(ChatModelIntegrationTests): + """Test `ChatDeepSeek` chat model.""" + @property def chat_model_class(self) -> type[ChatDeepSeek]: + """Return class of chat model being tested.""" return ChatDeepSeek @property def chat_model_params(self) -> dict: - # These should be parameters used to initialize your integration for testing + """Parameters to create chat model instance for testing.""" return { "model": "deepseek-chat", "temperature": 0, @@ -37,6 +38,7 @@ class TestChatDeepSeek(ChatModelIntegrationTests): model: BaseChatModel, my_adder_tool: BaseTool, ) -> None: + """Override test for tool message histories with list content.""" super().test_tool_message_histories_list_content(model, my_adder_tool) @@ -52,8 +54,9 @@ def test_reasoning_content() -> None: @pytest.mark.xfail(reason="Takes > 30s to run.") def test_reasoning_content_streaming() -> None: + """Test reasoning content with streaming.""" chat_model = ChatDeepSeek(model="deepseek-reasoner") - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in chat_model.stream("What is 3^3?"): full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) diff --git a/libs/partners/deepseek/tests/integration_tests/test_compile.py b/libs/partners/deepseek/tests/integration_tests/test_compile.py index f315e45f521..1fa3ee28f44 100644 --- a/libs/partners/deepseek/tests/integration_tests/test_compile.py +++ b/libs/partners/deepseek/tests/integration_tests/test_compile.py @@ -1,3 +1,5 @@ +"""Test compilation of integration tests.""" + import pytest diff --git a/libs/partners/deepseek/tests/unit_tests/__init__.py b/libs/partners/deepseek/tests/unit_tests/__init__.py index e69de29bb2d..57d3857fa70 100644 --- a/libs/partners/deepseek/tests/unit_tests/__init__.py +++ b/libs/partners/deepseek/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for `langchain_deepseek` package.""" diff --git a/libs/partners/deepseek/tests/unit_tests/test_chat_models.py b/libs/partners/deepseek/tests/unit_tests/test_chat_models.py index 2bb61a0fa70..d4bde140007 100644 --- a/libs/partners/deepseek/tests/unit_tests/test_chat_models.py +++ b/libs/partners/deepseek/tests/unit_tests/test_chat_models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Union +from typing import Any, Literal from unittest.mock import MagicMock from langchain_core.messages import AIMessageChunk, ToolMessage @@ -15,13 +15,15 @@ from langchain_deepseek.chat_models import ChatDeepSeek class MockOpenAIResponse(BaseModel): + """Mock OpenAI response model.""" + choices: list error: None = None def model_dump( # type: ignore[override] self, *, - mode: Union[Literal["json", "python"], str] = "python", # noqa: PYI051 + mode: Literal["json", "python"] | str = "python", # noqa: PYI051 include: Any = None, exclude: Any = None, by_alias: bool = False, @@ -29,10 +31,11 @@ class MockOpenAIResponse(BaseModel): exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: Union[Literal["none", "warn", "error"], bool] = True, - context: Union[dict[str, Any], None] = None, + warnings: Literal["none", "warn", "error"] | bool = True, + context: dict[str, Any] | None = None, serialize_as_any: bool = False, ) -> dict[str, Any]: + """Convert to dictionary, ensuring reasoning_content is included.""" choices_list = [] for choice in self.choices: if isinstance(choice.message, ChatCompletionMessage): @@ -58,12 +61,16 @@ class MockOpenAIResponse(BaseModel): class TestChatDeepSeekUnit(ChatModelUnitTests): + """Unit tests for `ChatDeepSeek` chat model.""" + @property def chat_model_class(self) -> type[ChatDeepSeek]: + """Chat model class being tested.""" return ChatDeepSeek @property def init_from_env_params(self) -> tuple[dict, dict, dict]: + """Parameters to initialize from environment variables.""" return ( { "DEEPSEEK_API_KEY": "api_key", @@ -80,7 +87,7 @@ class TestChatDeepSeekUnit(ChatModelUnitTests): @property def chat_model_params(self) -> dict: - # These should be parameters used to initialize your integration for testing + """Parameters to create chat model instance for testing.""" return { "model": "deepseek-chat", "api_key": "api_key", diff --git a/libs/partners/deepseek/uv.lock b/libs/partners/deepseek/uv.lock index c21fbc74721..8f97de18e14 100644 --- a/libs/partners/deepseek/uv.lock +++ b/libs/partners/deepseek/uv.lock @@ -212,7 +212,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -370,7 +370,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.0.0a6" +version = "1.0.0a8" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -475,7 +475,7 @@ typing = [{ name = "mypy", specifier = ">=1.10.0,<2.0.0" }] [[package]] name = "langchain-openai" -version = "1.0.0a3" +version = "1.0.0a4" source = { editable = "../openai" } dependencies = [ { name = "langchain-core" }, diff --git a/libs/partners/exa/langchain_exa/__init__.py b/libs/partners/exa/langchain_exa/__init__.py index 552a836a5e6..0f9427a7b43 100644 --- a/libs/partners/exa/langchain_exa/__init__.py +++ b/libs/partners/exa/langchain_exa/__init__.py @@ -1,3 +1,5 @@ +"""LangChain integration for Exa.""" + from exa_py.api import ( HighlightsContentsOptions, TextContentsOptions, diff --git a/libs/partners/exa/langchain_exa/retrievers.py b/libs/partners/exa/langchain_exa/retrievers.py index cce2f68f186..10d84396d66 100644 --- a/libs/partners/exa/langchain_exa/retrievers.py +++ b/libs/partners/exa/langchain_exa/retrievers.py @@ -1,6 +1,8 @@ +"""Retriever using Exa Search API.""" + from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from exa_py import Exa # type: ignore[untyped-import] from exa_py.api import ( @@ -39,39 +41,37 @@ class ExaSearchRetriever(BaseRetriever): k: int = 10 # num_results """The number of search results to return (1 to 100).""" - include_domains: Optional[list[str]] = None + include_domains: list[str] | None = None """A list of domains to include in the search.""" - exclude_domains: Optional[list[str]] = None + exclude_domains: list[str] | None = None """A list of domains to exclude from the search.""" - start_crawl_date: Optional[str] = None + start_crawl_date: str | None = None """The start date for the crawl (in YYYY-MM-DD format).""" - end_crawl_date: Optional[str] = None + end_crawl_date: str | None = None """The end date for the crawl (in YYYY-MM-DD format).""" - start_published_date: Optional[str] = None + start_published_date: str | None = None """The start date for when the document was published (in YYYY-MM-DD format).""" - end_published_date: Optional[str] = None + end_published_date: str | None = None """The end date for when the document was published (in YYYY-MM-DD format).""" - use_autoprompt: Optional[bool] = None + use_autoprompt: bool | None = None """Whether to use autoprompt for the search.""" type: str = "neural" """The type of search, 'keyword', 'neural', or 'auto'. Default: neural""" - highlights: Optional[Union[HighlightsContentsOptions, bool]] = None + highlights: HighlightsContentsOptions | bool | None = None """Whether to set the page content to the highlights of the results.""" - text_contents_options: Union[TextContentsOptions, dict[str, Any], Literal[True]] = ( - True - ) + text_contents_options: TextContentsOptions | dict[str, Any] | Literal[True] = True """How to set the page content of the results. Can be True or a dict with options like max_characters.""" - livecrawl: Optional[Literal["always", "fallback", "never"]] = None + livecrawl: Literal["always", "fallback", "never"] | None = None """Option to crawl live webpages if content is not in the index. Options: "always", "fallback", "never".""" - summary: Optional[Union[bool, dict[str, str]]] = None + summary: bool | dict[str, str] | None = None """Whether to include a summary of the content. Can be a boolean or a dict with a custom query.""" client: Exa = Field(default=None) exa_api_key: SecretStr = Field(default=None) - exa_base_url: Optional[str] = None + exa_base_url: str | None = None @model_validator(mode="before") @classmethod diff --git a/libs/partners/exa/langchain_exa/tools.py b/libs/partners/exa/langchain_exa/tools.py index 4e706f2abc7..0e2cdab2d3d 100644 --- a/libs/partners/exa/langchain_exa/tools.py +++ b/libs/partners/exa/langchain_exa/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from exa_py import Exa # type: ignore[untyped-import] from exa_py.api import ( @@ -103,22 +103,23 @@ class ExaSearchResults(BaseTool): # type: ignore[override] self, query: str, num_results: int = 10, - text_contents_options: Optional[ # noqa: FBT001 - Union[TextContentsOptions, dict[str, Any], bool] - ] = None, - highlights: Optional[Union[HighlightsContentsOptions, bool]] = None, # noqa: FBT001 - include_domains: Optional[list[str]] = None, - exclude_domains: Optional[list[str]] = None, - start_crawl_date: Optional[str] = None, - end_crawl_date: Optional[str] = None, - start_published_date: Optional[str] = None, - end_published_date: Optional[str] = None, - use_autoprompt: Optional[bool] = None, # noqa: FBT001 - livecrawl: Optional[Literal["always", "fallback", "never"]] = None, - summary: Optional[Union[bool, dict[str, str]]] = None, # noqa: FBT001 - type: Optional[Literal["neural", "keyword", "auto"]] = None, # noqa: A002 - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> Union[list[dict], str]: + text_contents_options: TextContentsOptions # noqa: FBT001 + | dict[str, Any] + | bool + | None = None, + highlights: HighlightsContentsOptions | bool | None = None, # noqa: FBT001 + include_domains: list[str] | None = None, + exclude_domains: list[str] | None = None, + start_crawl_date: str | None = None, + end_crawl_date: str | None = None, + start_published_date: str | None = None, + end_published_date: str | None = None, + use_autoprompt: bool | None = None, # noqa: FBT001 + livecrawl: Literal["always", "fallback", "never"] | None = None, + summary: bool | dict[str, str] | None = None, # noqa: FBT001 + type: Literal["neural", "keyword", "auto"] | None = None, # noqa: A002 + run_manager: CallbackManagerForToolRun | None = None, + ) -> list[dict] | str: # TODO: rename `type` to something else, as it is a reserved keyword """Use the tool. @@ -172,7 +173,7 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override] ) client: Exa = Field(default=None) exa_api_key: SecretStr = Field(default=None) - exa_base_url: Optional[str] = None + exa_base_url: str | None = None @model_validator(mode="before") @classmethod @@ -184,22 +185,23 @@ class ExaFindSimilarResults(BaseTool): # type: ignore[override] self, url: str, num_results: int = 10, - text_contents_options: Optional[ # noqa: FBT001 - Union[TextContentsOptions, dict[str, Any], bool] - ] = None, - highlights: Optional[Union[HighlightsContentsOptions, bool]] = None, # noqa: FBT001 - include_domains: Optional[list[str]] = None, - exclude_domains: Optional[list[str]] = None, - start_crawl_date: Optional[str] = None, - end_crawl_date: Optional[str] = None, - start_published_date: Optional[str] = None, - end_published_date: Optional[str] = None, - exclude_source_domain: Optional[bool] = None, # noqa: FBT001 - category: Optional[str] = None, - livecrawl: Optional[Literal["always", "fallback", "never"]] = None, - summary: Optional[Union[bool, dict[str, str]]] = None, # noqa: FBT001 - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> Union[list[dict], str]: + text_contents_options: TextContentsOptions # noqa: FBT001 + | dict[str, Any] + | bool + | None = None, + highlights: HighlightsContentsOptions | bool | None = None, # noqa: FBT001 + include_domains: list[str] | None = None, + exclude_domains: list[str] | None = None, + start_crawl_date: str | None = None, + end_crawl_date: str | None = None, + start_published_date: str | None = None, + end_published_date: str | None = None, + exclude_source_domain: bool | None = None, # noqa: FBT001 + category: str | None = None, + livecrawl: Literal["always", "fallback", "never"] | None = None, + summary: bool | dict[str, str] | None = None, # noqa: FBT001 + run_manager: CallbackManagerForToolRun | None = None, + ) -> list[dict] | str: """Use the tool. Args: diff --git a/libs/partners/exa/pyproject.toml b/libs/partners/exa/pyproject.toml index 7db8cbd8f17..1e44377b5f9 100644 --- a/libs/partners/exa/pyproject.toml +++ b/libs/partners/exa/pyproject.toml @@ -93,21 +93,12 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "D100", # Missing docstring in public module - "D101", # Missing docstring in public class - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - "D107", # Missing docstring in __init__ "COM812", # Messes with the formatter "ISC001", # Messes with the formatter "PERF203", # Rarely useful "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional ] unfixable = ["B028"] # People should intentionally tune the stacklevel @@ -131,3 +122,6 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa "S101", # Tests need assertions "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] +"scripts/*.py" = [ + "INP001", # Not a package +] diff --git a/libs/partners/exa/scripts/check_imports.py b/libs/partners/exa/scripts/check_imports.py index 58a460c1493..4a8b8a9cd86 100644 --- a/libs/partners/exa/scripts/check_imports.py +++ b/libs/partners/exa/scripts/check_imports.py @@ -1,3 +1,5 @@ +"""Check that the given files can be imported.""" + import sys import traceback from importlib.machinery import SourceFileLoader diff --git a/libs/partners/exa/tests/__init__.py b/libs/partners/exa/tests/__init__.py index e69de29bb2d..d0781463dec 100644 --- a/libs/partners/exa/tests/__init__.py +++ b/libs/partners/exa/tests/__init__.py @@ -0,0 +1 @@ +"""Exa tests.""" diff --git a/libs/partners/exa/tests/integration_tests/__init__.py b/libs/partners/exa/tests/integration_tests/__init__.py index e69de29bb2d..ea73d295f16 100644 --- a/libs/partners/exa/tests/integration_tests/__init__.py +++ b/libs/partners/exa/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""Exa integration tests.""" diff --git a/libs/partners/exa/tests/integration_tests/test_compile.py b/libs/partners/exa/tests/integration_tests/test_compile.py index e44fac59a0f..bb2268fca3a 100644 --- a/libs/partners/exa/tests/integration_tests/test_compile.py +++ b/libs/partners/exa/tests/integration_tests/test_compile.py @@ -1,3 +1,5 @@ +"""Test that the integration tests compile.""" + import pytest # type: ignore[import-not-found, import-not-found] diff --git a/libs/partners/exa/tests/integration_tests/test_find_similar_tool.py b/libs/partners/exa/tests/integration_tests/test_find_similar_tool.py index a2c4283ac6b..7f04c6e637c 100644 --- a/libs/partners/exa/tests/integration_tests/test_find_similar_tool.py +++ b/libs/partners/exa/tests/integration_tests/test_find_similar_tool.py @@ -1,9 +1,12 @@ +"""Integration tests for Exa find similar tool.""" + from langchain_exa import ( ExaFindSimilarResults, # type: ignore[import-not-found, import-not-found] ) def test_similarity_tool() -> None: + """Test that the Exa find similar tool works.""" tool = ExaFindSimilarResults() res = tool.invoke( { diff --git a/libs/partners/exa/tests/integration_tests/test_retriever.py b/libs/partners/exa/tests/integration_tests/test_retriever.py index dc49e192dba..96218ef0e53 100644 --- a/libs/partners/exa/tests/integration_tests/test_retriever.py +++ b/libs/partners/exa/tests/integration_tests/test_retriever.py @@ -1,3 +1,5 @@ +"""Integration tests for `ExaSearchRetriever`.""" + from langchain_core.documents import ( Document, # type: ignore[import-not-found, import-not-found] ) @@ -6,6 +8,7 @@ from langchain_exa import ExaSearchRetriever def test_exa_retriever() -> None: + """Test basic functionality of the `ExaSearchRetriever`.""" retriever = ExaSearchRetriever() res = retriever.invoke("best time to visit japan") print(res) # noqa: T201 @@ -15,6 +18,7 @@ def test_exa_retriever() -> None: def test_exa_retriever_highlights() -> None: + """Test highlights feature of the `ExaSearchRetriever`.""" retriever = ExaSearchRetriever(highlights=True) res = retriever.invoke("best time to visit japan") print(res) # noqa: T201 @@ -29,6 +33,7 @@ def test_exa_retriever_highlights() -> None: def test_exa_retriever_advanced_features() -> None: + """Test advanced features of the `ExaSearchRetriever`.""" retriever = ExaSearchRetriever( k=3, text_contents_options={"max_characters": 1000}, summary=True, type="auto" ) diff --git a/libs/partners/exa/tests/integration_tests/test_search_tool.py b/libs/partners/exa/tests/integration_tests/test_search_tool.py index 46fed61dd5e..e5acd1a246e 100644 --- a/libs/partners/exa/tests/integration_tests/test_search_tool.py +++ b/libs/partners/exa/tests/integration_tests/test_search_tool.py @@ -1,9 +1,12 @@ +"""Integration tests for Exa search tool.""" + from langchain_exa import ( ExaSearchResults, # type: ignore[import-not-found, import-not-found] ) def test_search_tool() -> None: + """Test that the Exa search tool works.""" tool = ExaSearchResults() res = tool.invoke({"query": "best time to visit japan", "num_results": 5}) print(res) # noqa: T201 @@ -11,6 +14,7 @@ def test_search_tool() -> None: def test_search_tool_advanced_features() -> None: + """Test advanced features of the Exa search tool.""" tool = ExaSearchResults() res = tool.invoke( { diff --git a/libs/partners/exa/tests/unit_tests/__init__.py b/libs/partners/exa/tests/unit_tests/__init__.py index e69de29bb2d..11429f1766f 100644 --- a/libs/partners/exa/tests/unit_tests/__init__.py +++ b/libs/partners/exa/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for `langchain_exa` package.""" diff --git a/libs/partners/exa/tests/unit_tests/test_imports.py b/libs/partners/exa/tests/unit_tests/test_imports.py index e7bb25b50cc..dcc4b3ebd6c 100644 --- a/libs/partners/exa/tests/unit_tests/test_imports.py +++ b/libs/partners/exa/tests/unit_tests/test_imports.py @@ -1,3 +1,5 @@ +"""Unit tests for imports in `langchain_exa`.""" + from langchain_exa import __all__ # type: ignore[import-not-found, import-not-found] EXPECTED_ALL = [ @@ -10,4 +12,5 @@ EXPECTED_ALL = [ def test_all_imports() -> None: + """Test that all expected imports are in `__all__`.""" assert sorted(EXPECTED_ALL) == sorted(__all__) diff --git a/libs/partners/exa/tests/unit_tests/test_standard.py b/libs/partners/exa/tests/unit_tests/test_standard.py index c92f5a05115..f510ccdaccd 100644 --- a/libs/partners/exa/tests/unit_tests/test_standard.py +++ b/libs/partners/exa/tests/unit_tests/test_standard.py @@ -1,3 +1,5 @@ +"""Standard unit tests for ExaSearchRetriever.""" + import pytest from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-untyped] diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 0c69a527c8d..f63e0fe4741 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -10,8 +10,6 @@ from operator import itemgetter from typing import ( Any, Literal, - Optional, - Union, cast, ) @@ -303,9 +301,9 @@ class ChatFireworks(BaseChatModel): async_client: Any = Field(default=None, exclude=True) #: :meta private: model_name: str = Field(alias="model") """Model name to use.""" - temperature: Optional[float] = None + temperature: float | None = None """What sampling temperature to use.""" - stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences") + stop: str | list[str] | None = Field(default=None, alias="stop_sequences") """Default stop sequences.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" @@ -325,12 +323,12 @@ class ChatFireworks(BaseChatModel): Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided. """ - fireworks_api_base: Optional[str] = Field( + fireworks_api_base: str | None = Field( alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None) ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" - request_timeout: Union[float, tuple[float, float], Any, None] = Field( + request_timeout: float | tuple[float, float] | Any | None = Field( default=None, alias="timeout" ) """Timeout for requests to Fireworks completion API. Can be ``float``, @@ -339,9 +337,9 @@ class ChatFireworks(BaseChatModel): """Whether to stream the results or not.""" n: int = 1 """Number of chat completions to generate for each prompt.""" - max_tokens: Optional[int] = None + max_tokens: int | None = None """Maximum number of tokens to generate.""" - max_retries: Optional[int] = None + max_retries: int | None = None """Maximum number of retries to make when generating.""" model_config = ConfigDict( @@ -401,7 +399,7 @@ class ChatFireworks(BaseChatModel): return params def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) @@ -417,7 +415,7 @@ class ChatFireworks(BaseChatModel): ls_params["ls_stop"] = ls_stop return ls_params - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: overall_token_usage: dict = {} system_fingerprint = None for output in llm_outputs: @@ -441,8 +439,8 @@ class ChatFireworks(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -476,9 +474,9 @@ class ChatFireworks(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -497,7 +495,7 @@ class ChatFireworks(BaseChatModel): return self._create_chat_result(response) def _create_message_dicts( - self, messages: list[BaseMessage], stop: Optional[list[str]] + self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict[str, Any]], dict[str, Any]]: params = self._default_params if stop is not None: @@ -505,7 +503,7 @@ class ChatFireworks(BaseChatModel): message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params - def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: + def _create_chat_result(self, response: dict | BaseModel) -> ChatResult: generations = [] if not isinstance(response, dict): response = response.model_dump() @@ -536,8 +534,8 @@ class ChatFireworks(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -573,9 +571,9 @@ class ChatFireworks(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -600,7 +598,7 @@ class ChatFireworks(BaseChatModel): return {"model_name": self.model_name, **self._default_params} def _get_invocation_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> dict[str, Any]: """Get the parameters used to invoke the model.""" return { @@ -617,11 +615,9 @@ class ChatFireworks(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], *, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051 - ] = None, + tool_choice: dict | str | bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -634,10 +630,10 @@ class ChatFireworks(BaseChatModel): `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. Must be the name of the single provided function, - ``'auto'`` to automatically determine which function to call + `'auto'` to automatically determine which function to call with the option to not call any function, `'any'` to enforce that some function is called, or a dict of the form: - ``{"type": "function", "function": {"name": <>}}``. + `{"type": "function", "function": {"name": <>}}`. **kwargs: Any additional parameters to pass to `langchain_fireworks.chat_models.ChatFireworks.bind` @@ -666,14 +662,14 @@ class ChatFireworks(BaseChatModel): def with_structured_output( self, - schema: Optional[Union[dict, type[BaseModel]]] = None, + schema: dict | type[BaseModel] | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, dict | BaseModel]: """Model wrapper that returns outputs formatted to match the given schema. Args: diff --git a/libs/partners/fireworks/langchain_fireworks/llms.py b/libs/partners/fireworks/langchain_fireworks/llms.py index 7f7fee8436f..7b245ce5e58 100644 --- a/libs/partners/fireworks/langchain_fireworks/llms.py +++ b/libs/partners/fireworks/langchain_fireworks/llms.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any, Optional +from typing import Any import requests from aiohttp import ClientSession, ClientTimeout @@ -56,35 +56,35 @@ class Fireworks(LLM): """ model: str """Model name. `(Available models) `__""" - temperature: Optional[float] = None + temperature: float | None = None """Model temperature.""" - top_p: Optional[float] = None + top_p: float | None = None """Used to dynamically adjust the number of choices for each predicted token based - on the cumulative probabilities. A value of ``1`` will always yield the same output. - A temperature less than ``1`` favors more correctness and is appropriate for - question answering or summarization. A value greater than ``1`` introduces more + on the cumulative probabilities. A value of `1` will always yield the same output. + A temperature less than `1` favors more correctness and is appropriate for + question answering or summarization. A value greater than `1` introduces more randomness in the output. """ model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for ``create`` call not explicitly specified.""" - top_k: Optional[int] = None + top_k: int | None = None """Used to limit the number of choices for the next predicted word or token. It specifies the maximum number of tokens to consider at each step, based on their probability of occurrence. This technique helps to speed up the generation process and can improve the quality of the generated text by focusing on the most likely options. """ - max_tokens: Optional[int] = None + max_tokens: int | None = None """The maximum number of tokens to generate.""" - repetition_penalty: Optional[float] = None + repetition_penalty: float | None = None """A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition. """ - logprobs: Optional[int] = None + logprobs: int | None = None """An integer that specifies how many top token log probabilities are included in the response for each token generation step. """ - timeout: Optional[int] = 30 + timeout: int | None = 30 """Timeout in seconds for requests to the Fireworks API.""" model_config = ConfigDict( @@ -125,8 +125,8 @@ class Fireworks(LLM): def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Call out to Fireworks's text generation endpoint. @@ -178,8 +178,8 @@ class Fireworks(LLM): async def _acall( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Call Fireworks model to get predictions based on the prompt. diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml index 70101e2a760..c2e1e620925 100644 --- a/libs/partners/fireworks/pyproject.toml +++ b/libs/partners/fireworks/pyproject.toml @@ -108,8 +108,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional ] unfixable = ["B028"] # People should intentionally tune the stacklevel diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index ab3683e5490..c1662d3a356 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -6,7 +6,7 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests. from __future__ import annotations import json -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal import pytest from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk @@ -48,7 +48,7 @@ async def test_astream() -> None: """Test streaming tokens from ChatFireworks.""" llm = ChatFireworks(model=_MODEL) - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_token_counts = 0 chunks_with_response_metadata = 0 async for token in llm.astream("I'm Pickle Rick"): diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 639838e9ab2..0fa0770289e 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -6,7 +6,7 @@ import json import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from operator import itemgetter -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -308,9 +308,9 @@ class ChatGroq(BaseChatModel): """Model name to use.""" temperature: float = 0.7 """What sampling temperature to use.""" - stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences") + stop: list[str] | str | None = Field(default=None, alias="stop_sequences") """Default stop sequences.""" - reasoning_format: Optional[Literal["parsed", "raw", "hidden"]] = Field(default=None) + reasoning_format: Literal["parsed", "raw", "hidden"] | None = Field(default=None) """The format for reasoning output. Groq will default to raw if left undefined. - ``'parsed'``: Separates reasoning into a dedicated field while keeping the @@ -325,7 +325,7 @@ class ChatGroq(BaseChatModel): See the `Groq documentation `__ for more details and a list of supported models. """ - reasoning_effort: Optional[str] = Field(default=None) + reasoning_effort: str | None = Field(default=None) """The level of effort the model will put into reasoning. Groq will default to enabling reasoning if left undefined. @@ -335,20 +335,18 @@ class ChatGroq(BaseChatModel): """ model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - groq_api_key: Optional[SecretStr] = Field( + groq_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None) ) """Automatically inferred from env var ``GROQ_API_KEY`` if not provided.""" - groq_api_base: Optional[str] = Field( + groq_api_base: str | None = Field( alias="base_url", default_factory=from_env("GROQ_API_BASE", default=None) ) """Base URL path for API requests. Leave blank if not using a proxy or service emulator.""" # to support explicit proxy for Groq - groq_proxy: Optional[str] = Field( - default_factory=from_env("GROQ_PROXY", default=None) - ) - request_timeout: Union[float, tuple[float, float], Any, None] = Field( + groq_proxy: str | None = Field(default_factory=from_env("GROQ_PROXY", default=None)) + request_timeout: float | tuple[float, float] | Any | None = Field( default=None, alias="timeout" ) """Timeout for requests to Groq completion API. Can be float, ``httpx.Timeout`` or @@ -359,7 +357,7 @@ class ChatGroq(BaseChatModel): """Whether to stream the results or not.""" n: int = 1 """Number of chat completions to generate for each prompt.""" - max_tokens: Optional[int] = None + max_tokens: int | None = None """Maximum number of tokens to generate.""" service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand") """Optional parameter that you can include to specify the service tier you'd like to @@ -369,20 +367,20 @@ class ChatGroq(BaseChatModel): - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts if resources are constrained. Provides balance between performance and reliability for workloads that don't require guaranteed processing. - - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those + - `'auto'`: Uses on-demand rate limits, then falls back to ``'flex'`` if those limits are exceeded See the `Groq documentation `__ for more details and a list of service tiers and descriptions. """ - default_headers: Union[Mapping[str, str], None] = None - default_query: Union[Mapping[str, object], None] = None + default_headers: Mapping[str, str] | None = None + default_query: Mapping[str, object] | None = None # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. - http_client: Union[Any, None] = None + http_client: Any | None = None """Optional ``httpx.Client``.""" - http_async_client: Union[Any, None] = None + http_async_client: Any | None = None """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify ``http_client`` as well if you'd like a custom client for sync invocations.""" @@ -490,7 +488,7 @@ class ChatGroq(BaseChatModel): return "groq-chat" def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) @@ -510,9 +508,9 @@ class ChatGroq(BaseChatModel): self, *, async_api: bool, - run_manager: Optional[ - Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun] - ] = None, + run_manager: CallbackManagerForLLMRun + | AsyncCallbackManagerForLLMRun + | None = None, **kwargs: Any, ) -> bool: """Determine if a given model call should hit the streaming API.""" @@ -532,8 +530,8 @@ class ChatGroq(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.streaming: @@ -552,8 +550,8 @@ class ChatGroq(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.streaming: @@ -573,8 +571,8 @@ class ChatGroq(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -625,8 +623,8 @@ class ChatGroq(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -738,7 +736,7 @@ class ChatGroq(BaseChatModel): return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( - self, messages: list[BaseMessage], stop: Optional[list[str]] + self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict[str, Any]], dict[str, Any]]: params = self._default_params if stop is not None: @@ -746,7 +744,7 @@ class ChatGroq(BaseChatModel): message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: overall_token_usage: dict = {} system_fingerprint = None for output in llm_outputs: @@ -771,11 +769,9 @@ class ChatGroq(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], *, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051 - ] = None, + tool_choice: dict | str | bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -786,10 +782,10 @@ class ChatGroq(BaseChatModel): `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. Must be the name of the single provided function, - "auto" to automatically determine which function to call - with the option to not call any function, "any" to enforce that some + `'auto'` to automatically determine which function to call + with the option to not call any function, `'any'` to enforce that some function is called, or a dict of the form: - ``{"type": "function", "function": {"name": <>}}``. + `{"type": "function", "function": {"name": <>}}`. **kwargs: Any additional parameters to pass to the `langchain.runnable.Runnable` constructor. @@ -820,7 +816,7 @@ class ChatGroq(BaseChatModel): def with_structured_output( self, - schema: Optional[Union[dict, type[BaseModel]]] = None, + schema: dict | type[BaseModel] | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index 9adc8052acf..b2b8df19cf6 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -62,8 +62,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "PLR0911", "PLR0912", "C901", diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index faa1a65573a..2e25e7ae627 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import Any, Optional, cast +from typing import Any, cast import pytest from groq import BadRequestError @@ -102,7 +102,7 @@ async def test_astream() -> None: """Test streaming tokens from Groq.""" chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10) - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_token_counts = 0 chunks_with_response_metadata = 0 async for token in chat.astream("Welcome to the Groqetship!"): @@ -254,7 +254,7 @@ def test_reasoning_output_stream() -> None: HumanMessage(content="I love programming."), ] - full_response: Optional[AIMessageChunk] = None + full_response: AIMessageChunk | None = None for token in chat.stream(message): assert isinstance(token, AIMessageChunk) diff --git a/libs/partners/groq/tests/unit_tests/fake/callbacks.py b/libs/partners/groq/tests/unit_tests/fake/callbacks.py index f38e8b56dc8..27a0169fddd 100644 --- a/libs/partners/groq/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/groq/tests/unit_tests/fake/callbacks.py @@ -3,7 +3,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, Optional, Union +from typing import Any from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler @@ -26,7 +26,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_chat_model_: bool = False # to allow for similar callback handlers that are not technically equal - fake_id: Union[str, None] = None + fake_id: str | None = None # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -270,7 +270,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: assert all(isinstance(m, BaseMessage) for m in chain(*messages)) diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py index fc26e8e902a..24048c245f3 100644 --- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py @@ -7,7 +7,7 @@ import json from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from dataclasses import dataclass from operator import itemgetter -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -316,7 +316,7 @@ class ChatHuggingFace(BaseChatModel): the HuggingFace Hub. Setup: - Install ``langchain-huggingface`` and ensure your Hugging Face token + Install `langchain-huggingface` and ensure your Hugging Face token is saved. .. code-block:: bash @@ -478,33 +478,33 @@ class ChatHuggingFace(BaseChatModel): HuggingFaceHub, or HuggingFacePipeline.""" tokenizer: Any = None """Tokenizer for the model. Only used for HuggingFacePipeline.""" - model_id: Optional[str] = None + model_id: str | None = None """Model ID for the model. Only used for HuggingFaceEndpoint.""" - temperature: Optional[float] = None + temperature: float | None = None """What sampling temperature to use.""" - stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences") + stop: str | list[str] | None = Field(default=None, alias="stop_sequences") """Default stop sequences.""" - presence_penalty: Optional[float] = None + presence_penalty: float | None = None """Penalizes repeated tokens.""" - frequency_penalty: Optional[float] = None + frequency_penalty: float | None = None """Penalizes repeated tokens according to frequency.""" - seed: Optional[int] = None + seed: int | None = None """Seed for generation""" - logprobs: Optional[bool] = None + logprobs: bool | None = None """Whether to return logprobs.""" - top_logprobs: Optional[int] = None + top_logprobs: int | None = None """Number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to true if this parameter is used.""" - logit_bias: Optional[dict[int, int]] = None + logit_bias: dict[int, int] | None = None """Modify the likelihood of specified tokens appearing in the completion.""" streaming: bool = False """Whether to stream the results or not.""" - n: Optional[int] = None + n: int | None = None """Number of chat completions to generate for each prompt.""" - top_p: Optional[float] = None + top_p: float | None = None """Total probability mass of tokens to consider at each step.""" - max_tokens: Optional[int] = None + max_tokens: int | None = None """Maximum number of tokens to generate.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" @@ -558,9 +558,9 @@ class ChatHuggingFace(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -599,9 +599,9 @@ class ChatHuggingFace(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: if _is_huggingface_textgen_inference(self.llm): @@ -638,8 +638,8 @@ class ChatHuggingFace(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: if _is_huggingface_endpoint(self.llm): @@ -687,8 +687,8 @@ class ChatHuggingFace(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -779,7 +779,7 @@ class ChatHuggingFace(BaseChatModel): self.model_id = self.llm.repo_id return if _is_huggingface_textgen_inference(self.llm): - endpoint_url: Optional[str] = self.llm.inference_server_url + endpoint_url: str | None = self.llm.inference_server_url if _is_huggingface_pipeline(self.llm): from transformers import AutoTokenizer # type: ignore[import] @@ -809,11 +809,9 @@ class ChatHuggingFace(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type | Callable | BaseTool], *, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required"], bool] # noqa: PYI051 - ] = None, + tool_choice: dict | str | bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -826,7 +824,7 @@ class ChatHuggingFace(BaseChatModel): `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. Must be the name of the single provided function or - ``'auto'`` to automatically determine which function to call + `'auto'` to automatically determine which function to call (if any), or a dict of the form: {"type": "function", "function": {"name": <>}}. **kwargs: Any additional parameters to pass to the @@ -870,14 +868,14 @@ class ChatHuggingFace(BaseChatModel): def with_structured_output( self, - schema: Optional[Union[dict, type[BaseModel]]] = None, + schema: dict | type[BaseModel] | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, dict | BaseModel]: """Model wrapper that returns outputs formatted to match the given schema. Args: @@ -948,7 +946,7 @@ class ChatHuggingFace(BaseChatModel): if is_pydantic_schema: msg = "Pydantic schema is not supported for function calling" raise NotImplementedError(msg) - output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = ( + output_parser: JsonOutputKeyToolsParser | JsonOutputParser = ( JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True) ) elif method == "json_schema": @@ -966,9 +964,7 @@ class ChatHuggingFace(BaseChatModel): "schema": schema, }, ) - output_parser: Union[ # type: ignore[no-redef] - JsonOutputKeyToolsParser, JsonOutputParser - ] = JsonOutputParser() # type: ignore[arg-type] + output_parser = JsonOutputParser() # type: ignore[arg-type] elif method == "json_mode": llm = self.bind( response_format={"type": "json_object"}, @@ -977,9 +973,7 @@ class ChatHuggingFace(BaseChatModel): "schema": schema, }, ) - output_parser: Union[ # type: ignore[no-redef] - JsonOutputKeyToolsParser, JsonOutputParser - ] = JsonOutputParser() # type: ignore[arg-type] + output_parser = JsonOutputParser() # type: ignore[arg-type] else: msg = ( f"Unrecognized method argument. Expected one of 'function_calling' or " @@ -999,7 +993,7 @@ class ChatHuggingFace(BaseChatModel): return llm | output_parser def _create_message_dicts( - self, messages: list[BaseMessage], stop: Optional[list[str]] + self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict[str, Any]], dict[str, Any]]: params = self._default_params if stop is not None: diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py index 627431f2115..179c99d6901 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from langchain_core.embeddings import Embeddings from pydantic import BaseModel, ConfigDict, Field @@ -40,7 +40,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model") """Model name to use.""" - cache_folder: Optional[str] = None + cache_folder: str | None = None """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py index 2d3bfe88b79..dc2b2584879 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Any, Optional +from typing import Any from langchain_core.embeddings import Embeddings from langchain_core.utils import from_env @@ -35,20 +35,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings): client: Any = None #: :meta private: async_client: Any = None #: :meta private: - model: Optional[str] = None + model: str | None = None """Model name to use.""" - provider: Optional[str] = None + provider: str | None = None """Name of the provider to use for inference with the model specified in ``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API. available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks).""" - repo_id: Optional[str] = None + repo_id: str | None = None """Huggingfacehub repository id, for backward compatibility.""" - task: Optional[str] = "feature-extraction" + task: str | None = "feature-extraction" """Task to call the model with.""" - model_kwargs: Optional[dict] = None + model_kwargs: dict | None = None """Keyword arguments to pass to the model.""" - huggingfacehub_api_token: Optional[str] = Field( + huggingfacehub_api_token: str | None = Field( default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None) ) diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py index 5b34deccc45..cdc7423039c 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py @@ -4,7 +4,7 @@ import inspect import logging import os from collections.abc import AsyncIterator, Iterator, Mapping -from typing import Any, Optional +from typing import Any from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -79,42 +79,42 @@ class HuggingFaceEndpoint(LLM): """ # noqa: E501 - endpoint_url: Optional[str] = None + endpoint_url: str | None = None """Endpoint URL to use. If repo_id is not specified then this needs to given or should be pass as env variable in `HF_INFERENCE_ENDPOINT`""" - repo_id: Optional[str] = None + repo_id: str | None = None """Repo to use. If endpoint_url is not specified then this needs to given""" - provider: Optional[str] = None + provider: str | None = None """Name of the provider to use for inference with the model specified in `repo_id`. e.g. "cerebras". if not specified, Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks).""" - huggingfacehub_api_token: Optional[str] = Field( + huggingfacehub_api_token: str | None = Field( default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None) ) max_new_tokens: int = 512 """Maximum number of generated tokens""" - top_k: Optional[int] = None + top_k: int | None = None """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" - top_p: Optional[float] = 0.95 + top_p: float | None = 0.95 """If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.""" - typical_p: Optional[float] = 0.95 + typical_p: float | None = 0.95 """Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.""" - temperature: Optional[float] = 0.8 + temperature: float | None = 0.8 """The value used to module the logits distribution.""" - repetition_penalty: Optional[float] = None + repetition_penalty: float | None = None """The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.""" return_full_text: bool = False """Whether to prepend the prompt to the generated text""" - truncate: Optional[int] = None + truncate: int | None = None """Truncate inputs tokens to the given size""" stop_sequences: list[str] = Field(default_factory=list) """Stop generating tokens if a member of `stop_sequences` is generated""" - seed: Optional[int] = None + seed: int | None = None """Random sampling seed""" inference_server_url: str = "" """text-generation-inference instance base url""" @@ -134,7 +134,7 @@ class HuggingFaceEndpoint(LLM): model: str client: Any = None #: :meta private: async_client: Any = None #: :meta private: - task: Optional[str] = None + task: str | None = None """Task to call the model with. Should be a task that returns `generated_text`.""" model_config = ConfigDict( @@ -292,7 +292,7 @@ class HuggingFaceEndpoint(LLM): return "huggingface_endpoint" def _invocation_params( - self, runtime_stop: Optional[list[str]], **kwargs: Any + self, runtime_stop: list[str] | None, **kwargs: Any ) -> dict[str, Any]: params = {**self._default_params, **kwargs} params["stop"] = params["stop"] + (runtime_stop or []) @@ -301,8 +301,8 @@ class HuggingFaceEndpoint(LLM): def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Call out to HuggingFace Hub's inference endpoint.""" @@ -331,8 +331,8 @@ class HuggingFaceEndpoint(LLM): async def _acall( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: invocation_params = self._invocation_params(stop, **kwargs) @@ -361,8 +361,8 @@ class HuggingFaceEndpoint(LLM): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: invocation_params = self._invocation_params(stop, **kwargs) @@ -371,13 +371,13 @@ class HuggingFaceEndpoint(LLM): prompt, **invocation_params, stream=True ): # identify stop sequence in generated text, if any - stop_seq_found: Optional[str] = None + stop_seq_found: str | None = None for stop_seq in invocation_params["stop"]: if stop_seq in response: stop_seq_found = stop_seq # identify text to yield - text: Optional[str] = None + text: str | None = None if stop_seq_found: text = response[: response.index(stop_seq_found)] else: @@ -398,8 +398,8 @@ class HuggingFaceEndpoint(LLM): async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: invocation_params = self._invocation_params(stop, **kwargs) @@ -407,13 +407,13 @@ class HuggingFaceEndpoint(LLM): prompt, **invocation_params, stream=True ): # identify stop sequence in generated text, if any - stop_seq_found: Optional[str] = None + stop_seq_found: str | None = None for stop_seq in invocation_params["stop"]: if stop_seq in response: stop_seq_found = stop_seq # identify text to yield - text: Optional[str] = None + text: str | None = None if stop_seq_found: text = response[: response.index(stop_seq_found)] else: diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py index 39f3bd6877c..c4d1fedcfd8 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py @@ -3,7 +3,7 @@ from __future__ import annotations # type: ignore[import-not-found] import importlib.util import logging from collections.abc import Iterator, Mapping -from typing import Any, Optional +from typing import Any from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import BaseLLM @@ -71,13 +71,13 @@ class HuggingFacePipeline(BaseLLM): """ pipeline: Any = None #: :meta private: - model_id: Optional[str] = None + model_id: str | None = None """The model name. If not set explicitly by the user, it will be inferred from the provided pipeline (if available). If neither is provided, the DEFAULT_MODEL_ID will be used.""" - model_kwargs: Optional[dict] = None + model_kwargs: dict | None = None """Keyword arguments passed to the model.""" - pipeline_kwargs: Optional[dict] = None + pipeline_kwargs: dict | None = None """Keyword arguments passed to the pipeline.""" batch_size: int = DEFAULT_BATCH_SIZE """Batch size to use when passing multiple documents to generate.""" @@ -103,10 +103,10 @@ class HuggingFacePipeline(BaseLLM): model_id: str, task: str, backend: str = "default", - device: Optional[int] = None, - device_map: Optional[str] = None, - model_kwargs: Optional[dict] = None, - pipeline_kwargs: Optional[dict] = None, + device: int | None = None, + device_map: str | None = None, + model_kwargs: dict | None = None, + pipeline_kwargs: dict | None = None, batch_size: int = DEFAULT_BATCH_SIZE, **kwargs: Any, ) -> HuggingFacePipeline: @@ -311,8 +311,8 @@ class HuggingFacePipeline(BaseLLM): def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: # List to hold all results @@ -363,8 +363,8 @@ class HuggingFacePipeline(BaseLLM): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: from threading import Thread diff --git a/libs/partners/huggingface/langchain_huggingface/utils/import_utils.py b/libs/partners/huggingface/langchain_huggingface/utils/import_utils.py index f086e327d0b..f217517a980 100644 --- a/libs/partners/huggingface/langchain_huggingface/utils/import_utils.py +++ b/libs/partners/huggingface/langchain_huggingface/utils/import_utils.py @@ -3,7 +3,6 @@ from __future__ import annotations import importlib.metadata import importlib.util import operator as op -from typing import Union from packaging import version @@ -44,7 +43,7 @@ _openvino_available = importlib.util.find_spec("openvino") is not None # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 def compare_versions( - library_or_version: Union[str, version.Version], + library_or_version: str | version.Version, operation: str, requirement_version: str, ) -> bool: diff --git a/libs/partners/huggingface/pyproject.toml b/libs/partners/huggingface/pyproject.toml index e8e669aedf9..00f279553d0 100644 --- a/libs/partners/huggingface/pyproject.toml +++ b/libs/partners/huggingface/pyproject.toml @@ -119,8 +119,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional ] unfixable = ["B028"] # People should intentionally tune the stacklevel diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index edea4f27bc4..ff46ad35c75 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -13,8 +13,6 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, - Union, cast, ) @@ -93,9 +91,7 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where()) def _create_retry_decorator( llm: ChatMistralAI, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, + run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None, ) -> Callable[[Any], Any]: """Return a tenacity retry decorator, preconfigured to handle exceptions.""" errors = [httpx.RequestError, httpx.StreamError] @@ -211,7 +207,7 @@ async def _aiter_sse( async def acompletion_with_retry( llm: ChatMistralAI, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Any: """Use tenacity to retry the async completion call.""" @@ -397,23 +393,23 @@ class ChatMistralAI(BaseChatModel): async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private: default=None, exclude=True ) #: :meta private: - mistral_api_key: Optional[SecretStr] = Field( + mistral_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("MISTRAL_API_KEY", default=None), ) - endpoint: Optional[str] = Field(default=None, alias="base_url") + endpoint: str | None = Field(default=None, alias="base_url") max_retries: int = 5 timeout: int = 120 max_concurrent_requests: int = 64 model: str = Field(default="mistral-small", alias="model_name") temperature: float = 0.7 - max_tokens: Optional[int] = None + max_tokens: int | None = None top_p: float = 1 """Decode using nucleus sampling: consider the smallest set of tokens whose probability sum is at least ``top_p``. Must be in the closed interval ``[0.0, 1.0]``.""" - random_seed: Optional[int] = None - safe_mode: Optional[bool] = None + random_seed: int | None = None + safe_mode: bool | None = None streaming: bool = False model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any invocation parameters not explicitly specified.""" @@ -445,7 +441,7 @@ class ChatMistralAI(BaseChatModel): return {k: v for k, v in defaults.items() if v is not None} def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) @@ -467,7 +463,7 @@ class ChatMistralAI(BaseChatModel): return self._default_params def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + self, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any ) -> Any: """Use tenacity to retry the completion call.""" retry_decorator = _create_retry_decorator(self, run_manager=run_manager) @@ -496,7 +492,7 @@ class ChatMistralAI(BaseChatModel): return _completion_with_retry(**kwargs) - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: overall_token_usage: dict = {} for output in llm_outputs: if output is None: @@ -515,7 +511,7 @@ class ChatMistralAI(BaseChatModel): def validate_environment(self) -> Self: """Validate api key, python package exists, temperature, and top_p.""" if isinstance(self.mistral_api_key, SecretStr): - api_key_str: Optional[str] = self.mistral_api_key.get_secret_value() + api_key_str: str | None = self.mistral_api_key.get_secret_value() else: api_key_str = self.mistral_api_key @@ -563,9 +559,9 @@ class ChatMistralAI(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -608,7 +604,7 @@ class ChatMistralAI(BaseChatModel): return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( - self, messages: list[BaseMessage], stop: Optional[list[str]] + self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict], dict[str, Any]]: params = self._client_params if stop is not None or "stop" in params: @@ -623,8 +619,8 @@ class ChatMistralAI(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -649,8 +645,8 @@ class ChatMistralAI(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -675,9 +671,9 @@ class ChatMistralAI(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, # noqa: FBT001 + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -696,8 +692,8 @@ class ChatMistralAI(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], - tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, # noqa: PYI051 + tools: Sequence[dict[str, Any] | type | Callable | BaseTool], + tool_choice: dict | str | Literal["auto", "any"] | None = None, # noqa: PYI051 **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -710,7 +706,7 @@ class ChatMistralAI(BaseChatModel): `langchain_core.utils.function_calling.convert_to_openai_tool`. tool_choice: Which tool to require the model to call. Must be the name of the single provided function or - ``'auto'`` to automatically determine which function to call + `'auto'` to automatically determine which function to call (if any), or a dict of the form: {"type": "function", "function": {"name": <>}}. kwargs: Any additional parameters are passed directly to @@ -738,14 +734,14 @@ class ChatMistralAI(BaseChatModel): def with_structured_output( self, - schema: Optional[Union[dict, type]] = None, + schema: dict | type | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, dict | BaseModel]: r"""Model wrapper that returns outputs formatted to match the given schema. Args: @@ -1085,7 +1081,7 @@ class ChatMistralAI(BaseChatModel): def _convert_to_openai_response_format( - schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None + schema: dict[str, Any] | type, *, strict: bool | None = None ) -> dict: """Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels.""" if ( diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 1459b8dd814..7b3f93045e1 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -62,8 +62,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "TD", "PLR0912", "C901", diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index 3e162d61063..a013e3a2bd9 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -5,7 +5,7 @@ from __future__ import annotations import json import logging import time -from typing import Any, Optional +from typing import Any import pytest from httpx import ReadTimeout @@ -33,7 +33,7 @@ async def test_astream() -> None: """Test streaming tokens from ChatMistralAI.""" llm = ChatMistralAI() - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_token_counts = 0 chunks_with_response_metadata = 0 async for token in llm.astream("Hello"): diff --git a/libs/partners/nomic/langchain_nomic/embeddings.py b/libs/partners/nomic/langchain_nomic/embeddings.py index 454c8847ffe..86f940dc2c0 100644 --- a/libs/partners/nomic/langchain_nomic/embeddings.py +++ b/libs/partners/nomic/langchain_nomic/embeddings.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Literal, Optional, overload +from typing import Literal, overload import nomic # type: ignore[import] from langchain_core.embeddings import Embeddings @@ -28,8 +28,8 @@ class NomicEmbeddings(Embeddings): self, *, model: str, - nomic_api_key: Optional[str] = ..., - dimensionality: Optional[int] = ..., + nomic_api_key: str | None = ..., + dimensionality: int | None = ..., inference_mode: Literal["remote"] = ..., ) -> None: ... @@ -38,10 +38,10 @@ class NomicEmbeddings(Embeddings): self, *, model: str, - nomic_api_key: Optional[str] = ..., - dimensionality: Optional[int] = ..., + nomic_api_key: str | None = ..., + dimensionality: int | None = ..., inference_mode: Literal["local", "dynamic"], - device: Optional[str] = ..., + device: str | None = ..., ) -> None: ... @overload @@ -49,21 +49,21 @@ class NomicEmbeddings(Embeddings): self, *, model: str, - nomic_api_key: Optional[str] = ..., - dimensionality: Optional[int] = ..., + nomic_api_key: str | None = ..., + dimensionality: int | None = ..., inference_mode: str, - device: Optional[str] = ..., + device: str | None = ..., ) -> None: ... def __init__( self, *, model: str, - nomic_api_key: Optional[str] = None, - dimensionality: Optional[int] = None, + nomic_api_key: str | None = None, + dimensionality: int | None = None, inference_mode: str = "remote", - device: Optional[str] = None, - vision_model: Optional[str] = None, + device: str | None = None, + vision_model: str | None = None, ): """Initialize NomicEmbeddings model. diff --git a/libs/partners/nomic/pyproject.toml b/libs/partners/nomic/pyproject.toml index 626e4c8558d..b344c896071 100644 --- a/libs/partners/nomic/pyproject.toml +++ b/libs/partners/nomic/pyproject.toml @@ -61,8 +61,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional # TODO "PLR0913", diff --git a/libs/partners/ollama/langchain_ollama/_utils.py b/libs/partners/ollama/langchain_ollama/_utils.py index 45e85e9932d..301d6975d24 100644 --- a/libs/partners/ollama/langchain_ollama/_utils.py +++ b/libs/partners/ollama/langchain_ollama/_utils.py @@ -3,7 +3,6 @@ from __future__ import annotations import base64 -from typing import Optional from urllib.parse import unquote, urlparse from httpx import ConnectError @@ -49,8 +48,8 @@ def validate_model(client: Client, model_name: str) -> None: def parse_url_with_auth( - url: Optional[str], -) -> tuple[Optional[str], Optional[dict[str, str]]]: + url: str | None, +) -> tuple[str | None, dict[str, str] | None]: """Parse URL and extract `userinfo` credentials for headers. Handles URLs of the form: `https://user:password@host:port/path` @@ -101,7 +100,7 @@ def parse_url_with_auth( def merge_auth_headers( client_kwargs: dict, - auth_headers: Optional[dict[str, str]], + auth_headers: dict[str, str] | None, ) -> None: """Merge authentication headers into client kwargs in-place. diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index 1ed910d0402..358741cd8b9 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -46,7 +46,7 @@ import json import logging from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from operator import itemgetter -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast from uuid import uuid4 from langchain_core.callbacks import CallbackManagerForLLMRun @@ -96,13 +96,13 @@ log = logging.getLogger(__name__) def _get_usage_metadata_from_generation_info( - generation_info: Optional[Mapping[str, Any]], -) -> Optional[UsageMetadata]: + generation_info: Mapping[str, Any] | None, +) -> UsageMetadata | None: """Get usage metadata from Ollama generation info mapping.""" if generation_info is None: return None - input_tokens: Optional[int] = generation_info.get("prompt_eval_count") - output_tokens: Optional[int] = generation_info.get("eval_count") + input_tokens: int | None = generation_info.get("prompt_eval_count") + output_tokens: int | None = generation_info.get("eval_count") if input_tokens is not None and output_tokens is not None: return UsageMetadata( input_tokens=input_tokens, @@ -166,7 +166,7 @@ def _parse_json_string( def _parse_arguments_from_tool_call( raw_tool_call: dict[str, Any], -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """Parse arguments by trying to parse any shallowly nested string-encoded JSON. Band-aid fix for issue in Ollama with inconsistent tool call argument structure. @@ -259,7 +259,7 @@ class ChatOllama(BaseChatModel): ???+ note "Setup" - Install ``langchain-ollama`` and download any models you want to use from ollama. + Install `langchain-ollama` and download any models you want to use from ollama. .. code-block:: bash @@ -523,7 +523,7 @@ class ChatOllama(BaseChatModel): model: str """Model name to use.""" - reasoning: Optional[Union[bool, str]] = None + reasoning: bool | str | None = None """Controls the reasoning/thinking mode for `supported models `__. - `True`: Enables reasoning mode. The model's reasoning process will be @@ -536,7 +536,7 @@ class ChatOllama(BaseChatModel): however, if the model's default behavior *is* to perform reasoning, think tags ()```` and ````) will be present within the main response content unless you set ``reasoning`` to `True`. - - `str`: e.g. ``'low'``, ``'medium'``, ``'high'``. Enables reasoning with a custom + - `str`: e.g. `'low'`, ``'medium'``, `'high'`. Enables reasoning with a custom intensity level. Currently, this is only supported ``gpt-oss``. See the `Ollama docs `__ for more information. @@ -548,13 +548,13 @@ class ChatOllama(BaseChatModel): !!! version-added "Added in version 0.3.4" """ - mirostat: Optional[int] = None + mirostat: int | None = None """Enable Mirostat sampling for controlling perplexity. - (Default: `0`, `0` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0) + (Default: `0`, `0` = disabled, `1` = Mirostat, `2` = Mirostat 2.0) """ - mirostat_eta: Optional[float] = None + mirostat_eta: float | None = None """Influences how quickly the algorithm responds to feedback from generated text. A lower learning rate will result in slower adjustments, while a higher learning @@ -563,7 +563,7 @@ class ChatOllama(BaseChatModel): (Default: ``0.1``) """ - mirostat_tau: Optional[float] = None + mirostat_tau: float | None = None """Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. @@ -571,19 +571,19 @@ class ChatOllama(BaseChatModel): (Default: ``5.0``) """ - num_ctx: Optional[int] = None + num_ctx: int | None = None """Sets the size of the context window used to generate the next token. (Default: ``2048``) """ - num_gpu: Optional[int] = None + num_gpu: int | None = None """The number of GPUs to use. - On macOS it defaults to ``1`` to enable metal support, `0` to disable. + On macOS it defaults to `1` to enable metal support, `0` to disable. """ - num_thread: Optional[int] = None + num_thread: int | None = None """Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to @@ -591,26 +591,26 @@ class ChatOllama(BaseChatModel): the logical number of cores). """ - num_predict: Optional[int] = None + num_predict: int | None = None """Maximum number of tokens to predict when generating text. (Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context) """ - repeat_last_n: Optional[int] = None + repeat_last_n: int | None = None """Sets how far back for the model to look back to prevent repetition. (Default: ``64``, `0` = disabled, ``-1`` = ``num_ctx``) """ - repeat_penalty: Optional[float] = None + repeat_penalty: float | None = None """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``) will penalize repetitions more strongly, while a lower value (e.g., ``0.9``) will be more lenient. (Default: ``1.1``) """ - temperature: Optional[float] = None + temperature: float | None = None """The temperature of the model. Increasing the temperature will make the model answer more creatively. @@ -618,17 +618,17 @@ class ChatOllama(BaseChatModel): (Default: ``0.8``) """ - seed: Optional[int] = None + seed: int | None = None """Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. """ - stop: Optional[list[str]] = None + stop: list[str] | None = None """Sets the stop tokens to use.""" - tfs_z: Optional[float] = None + tfs_z: float | None = None """Tail free sampling. Used to reduce the impact of less probable tokens from the output. @@ -636,10 +636,10 @@ class ChatOllama(BaseChatModel): A higher value (e.g., ``2.0``) will reduce the impact more, while a value of ``1.0`` disables this setting. - (Default: ``1``) + (Default: `1`) """ - top_k: Optional[int] = None + top_k: int | None = None """Reduces the probability of generating nonsense. A higher value (e.g. ``100``) will give more diverse answers, while a lower value @@ -648,7 +648,7 @@ class ChatOllama(BaseChatModel): (Default: ``40``) """ - top_p: Optional[float] = None + top_p: float | None = None """Works together with top-k. A higher value (e.g., ``0.95``) will lead to more diverse text, while a lower value @@ -657,13 +657,13 @@ class ChatOllama(BaseChatModel): (Default: ``0.9``) """ - format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None + format: Literal["", "json"] | JsonSchemaValue | None = None """Specify the format of the output (options: ``'json'``, JSON schema).""" - keep_alive: Optional[Union[int, str]] = None + keep_alive: int | str | None = None """How long the model will stay loaded into memory.""" - base_url: Optional[str] = None + base_url: str | None = None """Base url the model is hosted under. If none, defaults to the Ollama client default. @@ -685,7 +685,7 @@ class ChatOllama(BaseChatModel): """ - client_kwargs: Optional[dict] = {} + client_kwargs: dict | None = {} """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. @@ -694,7 +694,7 @@ class ChatOllama(BaseChatModel): to synchronous and asynchronous clients. """ - async_client_kwargs: Optional[dict] = {} + async_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the async client; for shared args use `client_kwargs`. @@ -702,7 +702,7 @@ class ChatOllama(BaseChatModel): For a full list of the params, see the `httpx documentation `__. """ - sync_client_kwargs: Optional[dict] = {} + sync_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the sync client; for shared args use `client_kwargs`. @@ -719,7 +719,7 @@ class ChatOllama(BaseChatModel): def _chat_params( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict[str, Any]: """Assemble the parameters for a chat completion request. @@ -834,8 +834,8 @@ class ChatOllama(BaseChatModel): ollama_messages: list = [] for message in messages: role: str - tool_call_id: Optional[str] = None - tool_calls: Optional[list[dict[str, Any]]] = None + tool_call_id: str | None = None + tool_calls: list[dict[str, Any]] | None = None if isinstance(message, HumanMessage): role = "user" elif isinstance(message, AIMessage): @@ -925,9 +925,9 @@ class ChatOllama(BaseChatModel): async def _acreate_chat_stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> AsyncIterator[Union[Mapping[str, Any], str]]: + ) -> AsyncIterator[Mapping[str, Any] | str]: chat_params = self._chat_params(messages, stop, **kwargs) if chat_params["stream"]: @@ -939,9 +939,9 @@ class ChatOllama(BaseChatModel): def _create_chat_stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> Iterator[Union[Mapping[str, Any], str]]: + ) -> Iterator[Mapping[str, Any] | str]: chat_params = self._chat_params(messages, stop, **kwargs) if chat_params["stream"]: @@ -953,8 +953,8 @@ class ChatOllama(BaseChatModel): def _chat_stream_with_aggregation( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> ChatGenerationChunk: @@ -979,8 +979,8 @@ class ChatOllama(BaseChatModel): async def _achat_stream_with_aggregation( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> ChatGenerationChunk: @@ -1003,7 +1003,7 @@ class ChatOllama(BaseChatModel): return final_chunk def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) @@ -1020,8 +1020,8 @@ class ChatOllama(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: final_chunk = self._chat_stream_with_aggregation( @@ -1044,7 +1044,7 @@ class ChatOllama(BaseChatModel): def _iterate_over_stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: reasoning = kwargs.get("reasoning", self.reasoning) @@ -1106,8 +1106,8 @@ class ChatOllama(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: for chunk in self._iterate_over_stream(messages, stop, **kwargs): @@ -1121,7 +1121,7 @@ class ChatOllama(BaseChatModel): async def _aiterate_over_stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: reasoning = kwargs.get("reasoning", self.reasoning) @@ -1183,8 +1183,8 @@ class ChatOllama(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: async for chunk in self._aiterate_over_stream(messages, stop, **kwargs): @@ -1198,8 +1198,8 @@ class ChatOllama(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: final_chunk = await self._achat_stream_with_aggregation( @@ -1226,9 +1226,9 @@ class ChatOllama(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type | Callable | BaseTool], *, - tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None, # noqa: PYI051, ARG002 + tool_choice: dict | str | Literal["auto", "any"] | bool | None = None, # noqa: PYI051, ARG002 **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -1249,12 +1249,12 @@ class ChatOllama(BaseChatModel): def with_structured_output( self, - schema: Union[dict, type], + schema: dict | type, *, method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema", include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, dict | BaseModel]: r"""Model wrapper that returns outputs formatted to match the given schema. Args: diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index 649fcdbf4b4..e6e0a3dec4d 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from langchain_core.embeddings import Embeddings from ollama import AsyncClient, Client @@ -128,7 +128,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): """ - base_url: Optional[str] = None + base_url: str | None = None """Base url the model is hosted under. If none, defaults to the Ollama client default. @@ -150,7 +150,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): """ - client_kwargs: Optional[dict] = {} + client_kwargs: dict | None = {} """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. @@ -159,7 +159,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): to synchronous and asynchronous clients. """ - async_client_kwargs: Optional[dict] = {} + async_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the async client; for shared args use `client_kwargs`. @@ -167,7 +167,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): For a full list of the params, see the `httpx documentation `__. """ - sync_client_kwargs: Optional[dict] = {} + sync_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the sync client; for shared args use `client_kwargs`. @@ -175,73 +175,73 @@ class OllamaEmbeddings(BaseModel, Embeddings): For a full list of the params, see the `httpx documentation `__. """ - _client: Optional[Client] = PrivateAttr(default=None) + _client: Client | None = PrivateAttr(default=None) """The client to use for making requests.""" - _async_client: Optional[AsyncClient] = PrivateAttr(default=None) + _async_client: AsyncClient | None = PrivateAttr(default=None) """The async client to use for making requests.""" - mirostat: Optional[int] = None + mirostat: int | None = None """Enable Mirostat sampling for controlling perplexity. - (default: `0`, `0` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)""" + (default: `0`, `0` = disabled, `1` = Mirostat, `2` = Mirostat 2.0)""" - mirostat_eta: Optional[float] = None + mirostat_eta: float | None = None """Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: ``0.1``)""" - mirostat_tau: Optional[float] = None + mirostat_tau: float | None = None """Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: ``5.0``)""" - num_ctx: Optional[int] = None + num_ctx: int | None = None """Sets the size of the context window used to generate the next token. (Default: ``2048``) """ - num_gpu: Optional[int] = None - """The number of GPUs to use. On macOS it defaults to ``1`` to + num_gpu: int | None = None + """The number of GPUs to use. On macOS it defaults to `1` to enable metal support, `0` to disable.""" - keep_alive: Optional[int] = None + keep_alive: int | None = None """Controls how long the model will stay loaded into memory following the request (default: ``5m``) """ - num_thread: Optional[int] = None + num_thread: int | None = None """Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).""" - repeat_last_n: Optional[int] = None + repeat_last_n: int | None = None """Sets how far back for the model to look back to prevent repetition. (Default: ``64``, `0` = disabled, ``-1`` = ``num_ctx``)""" - repeat_penalty: Optional[float] = None + repeat_penalty: float | None = None """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``) will penalize repetitions more strongly, while a lower value (e.g., ``0.9``) will be more lenient. (Default: ``1.1``)""" - temperature: Optional[float] = None + temperature: float | None = None """The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: ``0.8``)""" - stop: Optional[list[str]] = None + stop: list[str] | None = None """Sets the stop tokens to use.""" - tfs_z: Optional[float] = None + tfs_z: float | None = None """Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., ``2.0``) will reduce the - impact more, while a value of ``1.0`` disables this setting. (default: ``1``)""" + impact more, while a value of ``1.0`` disables this setting. (default: `1`)""" - top_k: Optional[int] = None + top_k: int | None = None """Reduces the probability of generating nonsense. A higher value (e.g. ``100``) will give more diverse answers, while a lower value (e.g. ``10``) will be more conservative. (Default: ``40``)""" - top_p: Optional[float] = None + top_p: float | None = None """Works together with top-k. A higher value (e.g., ``0.95``) will lead to more diverse text, while a lower value (e.g., ``0.5``) will generate more focused and conservative text. (Default: ``0.9``)""" diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py index b495069fd70..69f16e4a357 100644 --- a/libs/partners/ollama/langchain_ollama/llms.py +++ b/libs/partners/ollama/langchain_ollama/llms.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator, Mapping -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -22,7 +22,7 @@ class OllamaLLM(BaseLLM): """Ollama large language models. Setup: - Install ``langchain-ollama`` and install/run the Ollama server locally: + Install `langchain-ollama` and install/run the Ollama server locally: .. code-block:: bash @@ -112,7 +112,7 @@ class OllamaLLM(BaseLLM): model: str """Model name to use.""" - reasoning: Optional[bool] = None + reasoning: bool | None = None """Controls the reasoning/thinking mode for `supported models `__. @@ -132,71 +132,71 @@ class OllamaLLM(BaseLLM): !!! version-added "Added in version 0.3.4" """ - mirostat: Optional[int] = None + mirostat: int | None = None """Enable Mirostat sampling for controlling perplexity. - (default: `0`, `0` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)""" + (default: `0`, `0` = disabled, `1` = Mirostat, `2` = Mirostat 2.0)""" - mirostat_eta: Optional[float] = None + mirostat_eta: float | None = None """Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: ``0.1``)""" - mirostat_tau: Optional[float] = None + mirostat_tau: float | None = None """Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: ``5.0``)""" - num_ctx: Optional[int] = None + num_ctx: int | None = None """Sets the size of the context window used to generate the next token. (Default: ``2048``)""" - num_gpu: Optional[int] = None - """The number of GPUs to use. On macOS it defaults to ``1`` to + num_gpu: int | None = None + """The number of GPUs to use. On macOS it defaults to `1` to enable metal support, `0` to disable.""" - num_thread: Optional[int] = None + num_thread: int | None = None """Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores).""" - num_predict: Optional[int] = None + num_predict: int | None = None """Maximum number of tokens to predict when generating text. (Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context)""" - repeat_last_n: Optional[int] = None + repeat_last_n: int | None = None """Sets how far back for the model to look back to prevent repetition. (Default: ``64``, `0` = disabled, ``-1`` = ``num_ctx``)""" - repeat_penalty: Optional[float] = None + repeat_penalty: float | None = None """Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``) will penalize repetitions more strongly, while a lower value (e.g., ``0.9``) will be more lenient. (Default: ``1.1``)""" - temperature: Optional[float] = None + temperature: float | None = None """The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: ``0.8``)""" - seed: Optional[int] = None + seed: int | None = None """Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt.""" - stop: Optional[list[str]] = None + stop: list[str] | None = None """Sets the stop tokens to use.""" - tfs_z: Optional[float] = None + tfs_z: float | None = None """Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., ``2.0``) will reduce the - impact more, while a value of 1.0 disables this setting. (default: ``1``)""" + impact more, while a value of 1.0 disables this setting. (default: `1`)""" - top_k: Optional[int] = None + top_k: int | None = None """Reduces the probability of generating nonsense. A higher value (e.g. ``100``) will give more diverse answers, while a lower value (e.g. ``10``) will be more conservative. (Default: ``40``)""" - top_p: Optional[float] = None + top_p: float | None = None """Works together with top-k. A higher value (e.g., ``0.95``) will lead to more diverse text, while a lower value (e.g., ``0.5``) will generate more focused and conservative text. (Default: ``0.9``)""" @@ -204,10 +204,10 @@ class OllamaLLM(BaseLLM): format: Literal["", "json"] = "" """Specify the format of the output (options: ``'json'``)""" - keep_alive: Optional[Union[int, str]] = None + keep_alive: int | str | None = None """How long the model will stay loaded into memory.""" - base_url: Optional[str] = None + base_url: str | None = None """Base url the model is hosted under. If none, defaults to the Ollama client default. @@ -229,7 +229,7 @@ class OllamaLLM(BaseLLM): """ - client_kwargs: Optional[dict] = {} + client_kwargs: dict | None = {} """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. @@ -238,7 +238,7 @@ class OllamaLLM(BaseLLM): to synchronous and asynchronous clients. """ - async_client_kwargs: Optional[dict] = {} + async_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the async client; for shared args use `client_kwargs`. @@ -246,7 +246,7 @@ class OllamaLLM(BaseLLM): For a full list of the params, see the `httpx documentation `__. """ - sync_client_kwargs: Optional[dict] = {} + sync_client_kwargs: dict | None = {} """Additional kwargs to merge with `client_kwargs` before passing to httpx client. These are clients unique to the sync client; for shared args use `client_kwargs`. @@ -254,16 +254,16 @@ class OllamaLLM(BaseLLM): For a full list of the params, see the `httpx documentation `__. """ - _client: Optional[Client] = PrivateAttr(default=None) + _client: Client | None = PrivateAttr(default=None) """The client to use for making requests.""" - _async_client: Optional[AsyncClient] = PrivateAttr(default=None) + _async_client: AsyncClient | None = PrivateAttr(default=None) """The async client to use for making requests.""" def _generate_params( self, prompt: str, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict[str, Any]: if self.stop is not None and stop is not None: @@ -310,7 +310,7 @@ class OllamaLLM(BaseLLM): return "ollama-llm" def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = super()._get_ls_params(stop=stop, **kwargs) @@ -343,9 +343,9 @@ class OllamaLLM(BaseLLM): async def _acreate_generate_stream( self, prompt: str, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> AsyncIterator[Union[Mapping[str, Any], str]]: + ) -> AsyncIterator[Mapping[str, Any] | str]: if self._async_client: async for part in await self._async_client.generate( **self._generate_params(prompt, stop=stop, **kwargs) @@ -355,9 +355,9 @@ class OllamaLLM(BaseLLM): def _create_generate_stream( self, prompt: str, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> Iterator[Union[Mapping[str, Any], str]]: + ) -> Iterator[Mapping[str, Any] | str]: if self._client: yield from self._client.generate( **self._generate_params(prompt, stop=stop, **kwargs) @@ -366,8 +366,8 @@ class OllamaLLM(BaseLLM): async def _astream_with_aggregation( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> GenerationChunk: @@ -408,8 +408,8 @@ class OllamaLLM(BaseLLM): def _stream_with_aggregation( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> GenerationChunk: @@ -450,8 +450,8 @@ class OllamaLLM(BaseLLM): def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: generations = [] @@ -469,8 +469,8 @@ class OllamaLLM(BaseLLM): async def _agenerate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: generations = [] @@ -488,8 +488,8 @@ class OllamaLLM(BaseLLM): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: reasoning = kwargs.get("reasoning", self.reasoning) @@ -519,8 +519,8 @@ class OllamaLLM(BaseLLM): async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: reasoning = kwargs.get("reasoning", self.reasoning) diff --git a/libs/partners/ollama/pyproject.toml b/libs/partners/ollama/pyproject.toml index 8876a5db4a3..63306aa0421 100644 --- a/libs/partners/ollama/pyproject.toml +++ b/libs/partners/ollama/pyproject.toml @@ -62,8 +62,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "FIX002", # TODOs "TD002", # TODO authors "TD003", # TODO missing url diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py index 235ba54fef0..e68f2e08555 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import Annotated from unittest.mock import MagicMock, patch import pytest @@ -120,11 +120,11 @@ def test_structured_output_deeply_nested(model: str) -> None: class Person(BaseModel): """Information about a person.""" - name: Optional[str] = Field(default=None, description="The name of the person") - hair_color: Optional[str] = Field( + name: str | None = Field(default=None, description="The name of the person") + hair_color: str | None = Field( default=None, description="The color of the person's hair if known" ) - height_in_meters: Optional[str] = Field( + height_in_meters: str | None = Field( default=None, description="Height measured in meters" ) @@ -243,7 +243,7 @@ async def test_tool_astreaming(model: str) -> None: ("model", "output_version"), [(REASONING_MODEL_NAME, None), (REASONING_MODEL_NAME, "v1")], ) -def test_agent_loop(model: str, output_version: Optional[str]) -> None: +def test_agent_loop(model: str, output_version: str | None) -> None: """Test agent loop with tool calling and message passing.""" @tool diff --git a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py index 87aa2ae2c80..d77fedb4e61 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py +++ b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import asyncio import os from functools import lru_cache -from typing import Any, Optional +from typing import Any import openai @@ -44,7 +44,7 @@ class _AsyncHttpxClientWrapper(openai.DefaultAsyncHttpxClient): def _build_sync_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _SyncHttpxClientWrapper: return _SyncHttpxClientWrapper( base_url=base_url @@ -55,7 +55,7 @@ def _build_sync_httpx_client( def _build_async_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _AsyncHttpxClientWrapper: return _AsyncHttpxClientWrapper( base_url=base_url @@ -67,20 +67,20 @@ def _build_async_httpx_client( @lru_cache def _cached_sync_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _SyncHttpxClientWrapper: return _build_sync_httpx_client(base_url, timeout) @lru_cache def _cached_async_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _AsyncHttpxClientWrapper: return _build_async_httpx_client(base_url, timeout) def _get_default_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _SyncHttpxClientWrapper: """Get default httpx client. @@ -95,7 +95,7 @@ def _get_default_httpx_client( def _get_default_async_httpx_client( - base_url: Optional[str], timeout: Any + base_url: str | None, timeout: Any ) -> _AsyncHttpxClientWrapper: """Get default httpx client. diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index def74b408fa..6ff3b932b9b 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -69,7 +69,7 @@ from __future__ import annotations import json from collections.abc import Iterable, Iterator -from typing import Any, Union, cast +from typing import Any, cast from langchain_core.messages import AIMessage, is_data_content_block from langchain_core.messages import content as types @@ -83,7 +83,7 @@ def _convert_to_v03_ai_message( ) -> AIMessage: """Mutate an `AIMessage` to the old-style v0.3 format.""" if isinstance(message.content, list): - new_content: list[Union[dict, str]] = [] + new_content: list[dict | str] = [] for block in message.content: if isinstance(block, dict): if block.get("type") == "reasoning": diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index 821742205d7..22dd4328c50 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import os from collections.abc import AsyncIterator, Awaitable, Callable, Iterator -from typing import Any, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeAlias, TypeVar import openai from langchain_core.language_models import LanguageModelInput @@ -23,8 +23,8 @@ logger = logging.getLogger(__name__) _BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[dict[str, Any], type[_BM]] -_DictOrPydantic = Union[dict, _BM] +_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[_BM] | type +_DictOrPydantic: TypeAlias = dict | _BM def _is_pydantic_class(obj: Any) -> bool: @@ -462,7 +462,7 @@ class AzureChatOpenAI(BaseChatOpenAI): """ # noqa: E501 - azure_endpoint: Optional[str] = Field( + azure_endpoint: str | None = Field( default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) ) """Your Azure endpoint, including the resource. @@ -471,7 +471,7 @@ class AzureChatOpenAI(BaseChatOpenAI): Example: ``https://example-resource.azure.openai.com/`` """ - deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment") + deployment_name: str | None = Field(default=None, alias="azure_deployment") """A model deployment. If given sets the base client URL to include ``/deployments/{azure_deployment}`` @@ -479,7 +479,7 @@ class AzureChatOpenAI(BaseChatOpenAI): !!! note This means you won't be able to use non-deployment endpoints. """ - openai_api_version: Optional[str] = Field( + openai_api_version: str | None = Field( alias="api_version", default_factory=from_env("OPENAI_API_VERSION", default=None), ) @@ -487,14 +487,14 @@ class AzureChatOpenAI(BaseChatOpenAI): # Check OPENAI_API_KEY for backwards compatibility. # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using # other forms of azure credentials. - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env( ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None ), ) """Automatically inferred from env var ``AZURE_OPENAI_API_KEY`` if not provided.""" - azure_ad_token: Optional[SecretStr] = Field( + azure_ad_token: SecretStr | None = Field( default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) ) """Your Azure Active Directory token. @@ -503,14 +503,14 @@ class AzureChatOpenAI(BaseChatOpenAI): For more, see `this page `__. """ - azure_ad_token_provider: Union[Callable[[], str], None] = None + azure_ad_token_provider: Callable[[], str] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every sync request. For async requests, will be invoked if ``azure_ad_async_token_provider`` is not provided. """ - azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + azure_ad_async_token_provider: Callable[[], Awaitable[str]] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every async request. @@ -529,7 +529,7 @@ class AzureChatOpenAI(BaseChatOpenAI): correct cost. """ - openai_api_type: Optional[str] = Field( + openai_api_type: str | None = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") ) """Legacy, for ``openai<1.0.0`` support.""" @@ -539,7 +539,7 @@ class AzureChatOpenAI(BaseChatOpenAI): ``base_url`` or ``azure_endpoint`` and update client params accordingly. """ - model_name: Optional[str] = Field(default=None, alias="model") # type: ignore[assignment] + model_name: str | None = Field(default=None, alias="model") # type: ignore[assignment] """Name of the deployed OpenAI model, e.g. ``'gpt-4o'``, ``'gpt-35-turbo'``, etc. Distinct from the Azure deployment name, which is set by the Azure user. @@ -549,7 +549,7 @@ class AzureChatOpenAI(BaseChatOpenAI): Does NOT affect completion. """ - disabled_params: Optional[dict[str, Any]] = Field(default=None) + disabled_params: dict[str, Any] | None = Field(default=None) """Parameters of the OpenAI client or chat.completions endpoint that should be disabled for the given model. @@ -571,7 +571,7 @@ class AzureChatOpenAI(BaseChatOpenAI): ``'parallel_tools_calls'`` will be disabled. """ - max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens") # type: ignore[assignment] + max_tokens: int | None = Field(default=None, alias="max_completion_tokens") # type: ignore[assignment] """Maximum number of tokens to generate.""" @classmethod @@ -731,7 +731,7 @@ class AzureChatOpenAI(BaseChatOpenAI): return params def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get the parameters used to invoke the model.""" params = super()._get_ls_params(stop=stop, **kwargs) @@ -749,8 +749,8 @@ class AzureChatOpenAI(BaseChatOpenAI): def _create_chat_result( self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[dict] = None, + response: dict | openai.BaseModel, + generation_info: dict | None = None, ) -> ChatResult: chat_result = super()._create_chat_result(response, generation_info) @@ -790,7 +790,7 @@ class AzureChatOpenAI(BaseChatOpenAI): self, input_: LanguageModelInput, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict: """Get the request payload, using deployment name for Azure Responses API.""" @@ -825,11 +825,11 @@ class AzureChatOpenAI(BaseChatOpenAI): def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema", include_raw: bool = False, - strict: Optional[bool] = None, + strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: r"""Model wrapper that returns outputs formatted to match the given schema. diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index b7a2db1d85e..2d4cd988bbf 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -20,9 +20,8 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, + TypeAlias, TypeVar, - Union, cast, ) from urllib.parse import urlparse @@ -207,7 +206,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: def _format_message_content( content: Any, api: Literal["chat/completions", "responses"] = "chat/completions", - role: Optional[str] = None, + role: str | None = None, ) -> Any: """Format message content.""" if content and isinstance(content, list): @@ -301,7 +300,7 @@ def _convert_message_to_dict( if "function_call" in message_dict or "tool_calls" in message_dict: message_dict["content"] = message_dict["content"] or None - audio: Optional[dict[str, Any]] = None + audio: dict[str, Any] | None = None for block in message.content: if ( isinstance(block, dict) @@ -396,8 +395,8 @@ def _convert_delta_to_message_chunk( def _update_token_usage( - overall_token_usage: Union[int, dict], new_usage: Union[int, dict] -) -> Union[int, dict]: + overall_token_usage: int | dict, new_usage: int | dict +) -> int | dict: # Token usage is either ints or dictionaries # `reasoning_tokens` is nested inside `completion_tokens_details` if isinstance(new_usage, int): @@ -449,8 +448,8 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None: _BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type] -_DictOrPydantic = Union[dict, _BM] +_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[_BM] | type +_DictOrPydantic: TypeAlias = dict | _BM class BaseChatOpenAI(BaseChatModel): @@ -462,27 +461,27 @@ class BaseChatOpenAI(BaseChatModel): root_async_client: Any = Field(default=None, exclude=True) #: :meta private: model_name: str = Field(default="gpt-3.5-turbo", alias="model") """Model name to use.""" - temperature: Optional[float] = None + temperature: float | None = None """What sampling temperature to use.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) - openai_api_base: Optional[str] = Field(default=None, alias="base_url") + openai_api_base: str | None = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501 - openai_organization: Optional[str] = Field(default=None, alias="organization") + openai_organization: str | None = Field(default=None, alias="organization") """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = Field( + openai_proxy: str | None = Field( default_factory=from_env("OPENAI_PROXY", default=None) ) - request_timeout: Union[float, tuple[float, float], Any, None] = Field( + request_timeout: float | tuple[float, float] | Any | None = Field( default=None, alias="timeout" ) """Timeout for requests to OpenAI completion API. Can be float, `httpx.Timeout` or `None`.""" - stream_usage: Optional[bool] = None + stream_usage: bool | None = None """Whether to include usage metadata in streaming output. If enabled, an additional message chunk will be generated during the stream including usage metadata. @@ -495,31 +494,31 @@ class BaseChatOpenAI(BaseChatModel): !!! warning "Behavior changed in 0.3.35" Enabled for default base URL and client. """ - max_retries: Optional[int] = None + max_retries: int | None = None """Maximum number of retries to make when generating.""" - presence_penalty: Optional[float] = None + presence_penalty: float | None = None """Penalizes repeated tokens.""" - frequency_penalty: Optional[float] = None + frequency_penalty: float | None = None """Penalizes repeated tokens according to frequency.""" - seed: Optional[int] = None + seed: int | None = None """Seed for generation""" - logprobs: Optional[bool] = None + logprobs: bool | None = None """Whether to return logprobs.""" - top_logprobs: Optional[int] = None + top_logprobs: int | None = None """Number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to true if this parameter is used.""" - logit_bias: Optional[dict[int, int]] = None + logit_bias: dict[int, int] | None = None """Modify the likelihood of specified tokens appearing in the completion.""" streaming: bool = False """Whether to stream the results or not.""" - n: Optional[int] = None + n: int | None = None """Number of chat completions to generate for each prompt.""" - top_p: Optional[float] = None + top_p: float | None = None """Total probability mass of tokens to consider at each step.""" - max_tokens: Optional[int] = Field(default=None) + max_tokens: int | None = Field(default=None) """Maximum number of tokens to generate.""" - reasoning_effort: Optional[str] = None + reasoning_effort: str | None = None """Constrains effort on reasoning for reasoning models. For use with the Chat Completions API. @@ -529,7 +528,7 @@ class BaseChatOpenAI(BaseChatModel): `'high'`. Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. """ - reasoning: Optional[dict[str, Any]] = None + reasoning: dict[str, Any] | None = None """Reasoning parameters for reasoning models. For use with the Responses API. ```python @@ -541,7 +540,7 @@ class BaseChatOpenAI(BaseChatModel): !!! version-added "Added in version 0.3.24" """ - verbosity: Optional[str] = None + verbosity: str | None = None """Controls the verbosity level of responses for reasoning models. For use with the Responses API. @@ -549,7 +548,7 @@ class BaseChatOpenAI(BaseChatModel): !!! version-added "Added in version 0.3.28" """ - tiktoken_model_name: Optional[str] = None + tiktoken_model_name: str | None = None """The model name to pass to tiktoken when using this class. Tiktoken is used to count the number of tokens in documents to constrain them to be under a certain limit. By default, when set to None, this will @@ -559,20 +558,20 @@ class BaseChatOpenAI(BaseChatModel): when using one of the many model providers that expose an OpenAI-like API but with different models. In those cases, in order to avoid erroring when tiktoken is called, you can specify a model name to use here.""" - default_headers: Union[Mapping[str, str], None] = None - default_query: Union[Mapping[str, object], None] = None + default_headers: Mapping[str, str] | None = None + default_query: Mapping[str, object] | None = None # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. - http_client: Union[Any, None] = Field(default=None, exclude=True) + http_client: Any | None = Field(default=None, exclude=True) """Optional `httpx.Client`. Only used for sync invocations. Must specify `http_async_client` as well if you'd like a custom client for async invocations. """ - http_async_client: Union[Any, None] = Field(default=None, exclude=True) + http_async_client: Any | None = Field(default=None, exclude=True) """Optional `httpx.AsyncClient`. Only used for async invocations. Must specify `http_client` as well if you'd like a custom client for sync invocations.""" - stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences") + stop: list[str] | str | None = Field(default=None, alias="stop_sequences") """Default stop sequences.""" - extra_body: Optional[Mapping[str, Any]] = None + extra_body: Mapping[str, Any] | None = None """Optional additional JSON properties to include in the request parameters when making requests to OpenAI compatible APIs, such as vLLM, LM Studio, or other providers. @@ -594,7 +593,7 @@ class BaseChatOpenAI(BaseChatModel): include_response_headers: bool = False """Whether to include response headers in the output message `response_metadata`.""" - disabled_params: Optional[dict[str, Any]] = Field(default=None) + disabled_params: dict[str, Any] | None = Field(default=None) """Parameters of the OpenAI client or `chat.completions` endpoint that should be disabled for the given model. @@ -611,7 +610,7 @@ class BaseChatOpenAI(BaseChatModel): passed in the parameter during invocation. """ - include: Optional[list[str]] = None + include: list[str] | None = None """Additional fields to include in generations from Responses API. Supported values: @@ -625,19 +624,19 @@ class BaseChatOpenAI(BaseChatModel): !!! version-added "Added in version 0.3.24" """ - service_tier: Optional[str] = None + service_tier: str | None = None """Latency tier for request. Options are `'auto'`, `'default'`, or `'flex'`. Relevant for users of OpenAI's scale tier service. """ - store: Optional[bool] = None + store: bool | None = None """If `True`, OpenAI may store response data for future use. Defaults to `True` for the Responses API and `False` for the Chat Completions API. !!! version-added "Added in version 0.3.24" """ - truncation: Optional[str] = None + truncation: str | None = None """Truncation strategy (Responses API). Can be `'auto'` or `'disabled'` (default). If `'auto'`, model may drop input items from the middle of the message sequence to fit the context window. @@ -676,7 +675,7 @@ class BaseChatOpenAI(BaseChatModel): !!! version-added "Added in version 0.3.26" """ - use_responses_api: Optional[bool] = None + use_responses_api: bool | None = None """Whether to use the Responses API instead of the Chat API. If not specified then will be inferred based on invocation params. @@ -684,7 +683,7 @@ class BaseChatOpenAI(BaseChatModel): !!! version-added "Added in version 0.3.9" """ - output_version: Optional[str] = Field( + output_version: str | None = Field( default_factory=from_env("LC_OUTPUT_VERSION", default=None) ) """Version of `AIMessage` output format to use. @@ -876,7 +875,7 @@ class BaseChatOpenAI(BaseChatModel): **self.model_kwargs, } - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: overall_token_usage: dict = {} system_fingerprint = None for output in llm_outputs: @@ -905,8 +904,8 @@ class BaseChatOpenAI(BaseChatModel): self, chunk: dict, default_chunk_class: type, - base_generation_info: Optional[dict], - ) -> Optional[ChatGenerationChunk]: + base_generation_info: dict | None, + ) -> ChatGenerationChunk | None: if chunk.get("type") == "content.delta": # From beta.chat.completions.stream return None token_usage = chunk.get("usage") @@ -916,7 +915,7 @@ class BaseChatOpenAI(BaseChatModel): or chunk.get("chunk", {}).get("choices", []) ) - usage_metadata: Optional[UsageMetadata] = ( + usage_metadata: UsageMetadata | None = ( _create_usage_metadata(token_usage) if token_usage else None ) if len(choices) == 0: @@ -966,8 +965,8 @@ class BaseChatOpenAI(BaseChatModel): def _stream_responses( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: kwargs["stream"] = True @@ -1019,8 +1018,8 @@ class BaseChatOpenAI(BaseChatModel): async def _astream_responses( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: kwargs["stream"] = True @@ -1072,7 +1071,7 @@ class BaseChatOpenAI(BaseChatModel): yield generation_chunk def _should_stream_usage( - self, stream_usage: Optional[bool] = None, **kwargs: Any + self, stream_usage: bool | None = None, **kwargs: Any ) -> bool: """Determine whether to include usage metadata in streaming output. @@ -1093,10 +1092,10 @@ class BaseChatOpenAI(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, *, - stream_usage: Optional[bool] = None, + stream_usage: bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: kwargs["stream"] = True @@ -1163,8 +1162,8 @@ class BaseChatOpenAI(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1233,7 +1232,7 @@ class BaseChatOpenAI(BaseChatModel): self, input_: LanguageModelInput, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict: messages = self._convert_input(input_).to_messages() @@ -1262,8 +1261,8 @@ class BaseChatOpenAI(BaseChatModel): def _create_chat_result( self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[dict] = None, + response: dict | openai.BaseModel, + generation_info: dict | None = None, ) -> ChatResult: generations = [] @@ -1329,10 +1328,10 @@ class BaseChatOpenAI(BaseChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, *, - stream_usage: Optional[bool] = None, + stream_usage: bool | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: kwargs["stream"] = True @@ -1403,8 +1402,8 @@ class BaseChatOpenAI(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: payload = self._get_request_payload(messages, stop=stop, **kwargs) @@ -1468,7 +1467,7 @@ class BaseChatOpenAI(BaseChatModel): return {"model_name": self.model_name, **self._default_params} def _get_invocation_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> dict[str, Any]: """Get the parameters used to invoke the model.""" params = { @@ -1489,7 +1488,7 @@ class BaseChatOpenAI(BaseChatModel): return params def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = self._get_invocation_params(stop=stop, **kwargs) @@ -1546,9 +1545,7 @@ class BaseChatOpenAI(BaseChatModel): def get_num_tokens_from_messages( self, messages: Sequence[BaseMessage], - tools: Optional[ - Sequence[Union[dict[str, Any], type, Callable, BaseTool]] - ] = None, + tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None, ) -> int: """Calculate num tokens for `gpt-3.5-turbo` and `gpt-4` with `tiktoken` package. @@ -1642,13 +1639,11 @@ class BaseChatOpenAI(BaseChatModel): def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type | Callable | BaseTool], *, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] # noqa: PYI051 - ] = None, - strict: Optional[bool] = None, - parallel_tool_calls: Optional[bool] = None, + tool_choice: dict | str | bool | None = None, + strict: bool | None = None, + parallel_tool_calls: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tool-like objects to this chat model. @@ -1723,14 +1718,14 @@ class BaseChatOpenAI(BaseChatModel): def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, - strict: Optional[bool] = None, - tools: Optional[list] = None, + strict: bool | None = None, + tools: list | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema. @@ -2740,7 +2735,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] cached tokens. """ # noqa: E501 - max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens") + max_tokens: int | None = Field(default=None, alias="max_completion_tokens") """Maximum number of tokens to generate.""" @property @@ -2787,7 +2782,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] self, input_: LanguageModelInput, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict: payload = super()._get_request_payload(input_, stop=stop, **kwargs) @@ -2822,11 +2817,11 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema", include_raw: bool = False, - strict: Optional[bool] = None, + strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: r"""Model wrapper that returns outputs formatted to match the given schema. @@ -3239,7 +3234,7 @@ def _lc_invalid_tool_call_to_openai_tool_call( } -def _url_to_size(image_source: str) -> Optional[tuple[int, int]]: +def _url_to_size(image_source: str) -> tuple[int, int] | None: try: from PIL import Image # type: ignore[import] except ImportError: @@ -3311,8 +3306,8 @@ def _resize(width: int, height: int) -> tuple[int, int]: def _convert_to_openai_response_format( - schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None -) -> Union[dict, TypeBaseModel]: + schema: dict[str, Any] | type, *, strict: bool | None = None +) -> dict | TypeBaseModel: if isinstance(schema, type) and is_basemodel_subclass(schema): return schema @@ -3351,7 +3346,7 @@ def _convert_to_openai_response_format( def _oai_structured_outputs_parser( ai_msg: AIMessage, schema: type[_BM] -) -> Optional[PydanticBaseModel]: +) -> PydanticBaseModel | None: if parsed := ai_msg.additional_kwargs.get("parsed"): if isinstance(parsed, dict): return schema(**parsed) @@ -3471,7 +3466,7 @@ def _use_responses_api(payload: dict) -> bool: def _get_last_messages( messages: Sequence[BaseMessage], -) -> tuple[Sequence[BaseMessage], Optional[str]]: +) -> tuple[Sequence[BaseMessage], str | None]: """Get the last part of the conversation after the last `AIMessage` with an `id`. Return: @@ -3620,7 +3615,7 @@ def _convert_chat_completions_blocks_to_responses( return block -def _ensure_valid_tool_message_content(tool_output: Any) -> Union[str, list[dict]]: +def _ensure_valid_tool_message_content(tool_output: Any) -> str | list[dict]: if isinstance(tool_output, str): return tool_output if isinstance(tool_output, list) and all( @@ -3645,8 +3640,8 @@ def _ensure_valid_tool_message_content(tool_output: Any) -> Union[str, list[dict def _make_computer_call_output_from_message( message: ToolMessage, -) -> Optional[dict[str, Any]]: - computer_call_output: Optional[dict[str, Any]] = None +) -> dict[str, Any] | None: + computer_call_output: dict[str, Any] | None = None if isinstance(message.content, list): for block in message.content: if ( @@ -3685,7 +3680,7 @@ def _make_computer_call_output_from_message( return computer_call_output -def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]: +def _make_custom_tool_output_from_message(message: ToolMessage) -> dict | None: custom_tool_output = None for block in message.content: if isinstance(block, dict) and block.get("type") == "custom_tool_call_output": @@ -3902,9 +3897,9 @@ def _get_output_text(response: Response) -> str: def _construct_lc_result_from_responses_api( response: Response, - schema: Optional[type[_BM]] = None, - metadata: Optional[dict] = None, - output_version: Optional[str] = None, + schema: type[_BM] | None = None, + metadata: dict | None = None, + output_version: str | None = None, ) -> ChatResult: """Construct `ChatResponse` from OpenAI Response API response.""" if response.error: @@ -4072,12 +4067,12 @@ def _convert_responses_chunk_to_generation_chunk( current_index: int, # index in content current_output_index: int, # index in Response output current_sub_index: int, # index of content block in output item - schema: Optional[type[_BM]] = None, - metadata: Optional[dict] = None, + schema: type[_BM] | None = None, + metadata: dict | None = None, has_reasoning: bool = False, - output_version: Optional[str] = None, -) -> tuple[int, int, int, Optional[ChatGenerationChunk]]: - def _advance(output_idx: int, sub_idx: Optional[int] = None) -> None: + output_version: str | None = None, +) -> tuple[int, int, int, ChatGenerationChunk | None]: + def _advance(output_idx: int, sub_idx: int | None = None) -> None: """Advance indexes tracked during streaming. Example: we stream a response item of the form: @@ -4133,7 +4128,7 @@ def _convert_responses_chunk_to_generation_chunk( response_metadata = metadata or {} response_metadata["model_provider"] = "openai" usage_metadata = None - chunk_position: Optional[Literal["last"]] = None + chunk_position: Literal["last"] | None = None id = None if chunk.type == "response.output_text.delta": _advance(chunk.output_index, chunk.content_index) diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index f0cb59a2bbe..c74d70fe8e6 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import Optional, Union, cast +from typing import cast import openai from langchain_core.utils import from_env, secret_from_env @@ -102,7 +102,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override] """ # noqa: E501 - azure_endpoint: Optional[str] = Field( + azure_endpoint: str | None = Field( default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) ) """Your Azure endpoint, including the resource. @@ -111,7 +111,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override] Example: `https://example-resource.azure.openai.com/` """ - deployment: Optional[str] = Field(default=None, alias="azure_deployment") + deployment: str | None = Field(default=None, alias="azure_deployment") """A model deployment. If given sets the base client URL to include `/deployments/{azure_deployment}`. @@ -123,14 +123,14 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override] # Check OPENAI_KEY for backwards compatibility. # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using # other forms of azure credentials. - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env( ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None ), ) """Automatically inferred from env var ``AZURE_OPENAI_API_KEY`` if not provided.""" - openai_api_version: Optional[str] = Field( + openai_api_version: str | None = Field( default_factory=from_env("OPENAI_API_VERSION", default="2023-05-15"), alias="api_version", ) @@ -139,7 +139,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override] Set to ``'2023-05-15'`` by default if env variable ``OPENAI_API_VERSION`` is not set. """ - azure_ad_token: Optional[SecretStr] = Field( + azure_ad_token: SecretStr | None = Field( default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) ) """Your Azure Active Directory token. @@ -148,18 +148,18 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): # type: ignore[override] `For more, see this page. `__ """ - azure_ad_token_provider: Union[Callable[[], str], None] = None + azure_ad_token_provider: Callable[[], str] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every sync request. For async requests, will be invoked if ``azure_ad_async_token_provider`` is not provided. """ - azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + azure_ad_async_token_provider: Callable[[], Awaitable[str]] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every async request. """ - openai_api_type: Optional[str] = Field( + openai_api_type: str | None = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") ) validate_base_url: bool = True diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 360057574e8..4c5e8ecfae3 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import warnings from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast import openai import tiktoken @@ -20,11 +20,11 @@ logger = logging.getLogger(__name__) def _process_batched_chunked_embeddings( num_texts: int, - tokens: list[Union[list[int], str]], + tokens: list[list[int] | str], batched_embeddings: list[list[float]], indices: list[int], skip_empty: bool, -) -> list[Optional[list[float]]]: +) -> list[list[float] | None]: # for each text, this is the list of embeddings (list of list of floats) # corresponding to the chunks of the text results: list[list[list[float]]] = [[] for _ in range(num_texts)] @@ -41,7 +41,7 @@ def _process_batched_chunked_embeddings( num_tokens_in_batch[indices[i]].append(len(tokens[i])) # for each text, this is the final embedding - embeddings: list[Optional[list[float]]] = [] + embeddings: list[list[float] | None] = [] for i in range(num_texts): # an embedding for each chunk _result: list[list[float]] = results[i] @@ -164,53 +164,53 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: model: str = "text-embedding-ada-002" - dimensions: Optional[int] = None + dimensions: int | None = None """The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. """ # to support Azure OpenAI Service custom deployment names - deployment: Optional[str] = model + deployment: str | None = model # TODO: Move to AzureOpenAIEmbeddings. - openai_api_version: Optional[str] = Field( + openai_api_version: str | None = Field( default_factory=from_env("OPENAI_API_VERSION", default=None), alias="api_version", ) """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" # to support Azure OpenAI Service custom endpoints - openai_api_base: Optional[str] = Field( + openai_api_base: str | None = Field( alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None) ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # to support Azure OpenAI Service custom endpoints - openai_api_type: Optional[str] = Field( + openai_api_type: str | None = Field( default_factory=from_env("OPENAI_API_TYPE", default=None) ) # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = Field( + openai_proxy: str | None = Field( default_factory=from_env("OPENAI_PROXY", default=None) ) embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided.""" - openai_organization: Optional[str] = Field( + openai_organization: str | None = Field( alias="organization", default_factory=from_env( ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None ), ) """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided.""" - allowed_special: Union[Literal["all"], set[str], None] = None - disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = None + allowed_special: Literal["all"] | set[str] | None = None + disallowed_special: Literal["all"] | set[str] | Sequence[str] | None = None chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 2 """Maximum number of retries to make when generating.""" - request_timeout: Optional[Union[float, tuple[float, float], Any]] = Field( + request_timeout: float | tuple[float, float] | Any | None = Field( default=None, alias="timeout" ) """Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or @@ -219,7 +219,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): tiktoken_enabled: bool = True """Set this to False for non-OpenAI implementations of the embeddings API, e.g. the ``--extensions openai`` extension for ``text-generation-webui``""" - tiktoken_model_name: Optional[str] = None + tiktoken_model_name: str | None = None """The model name to pass to tiktoken when using this class. Tiktoken is used to count the number of tokens in documents to constrain them to be under a certain limit. By default, when set to None, this will @@ -236,20 +236,20 @@ class OpenAIEmbeddings(BaseModel, Embeddings): skip_empty: bool = False """Whether to skip empty strings when embedding or raise an error. Defaults to not skipping.""" - default_headers: Union[Mapping[str, str], None] = None - default_query: Union[Mapping[str, object], None] = None + default_headers: Mapping[str, str] | None = None + default_query: Mapping[str, object] | None = None # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. retry_min_seconds: int = 4 """Min number of seconds to wait between retries""" retry_max_seconds: int = 20 """Max number of seconds to wait between retries""" - http_client: Union[Any, None] = None + http_client: Any | None = None """Optional ``httpx.Client``. Only used for sync invocations. Must specify ``http_async_client`` as well if you'd like a custom client for async invocations. """ - http_async_client: Union[Any, None] = None + http_async_client: Any | None = None """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify ``http_client`` as well if you'd like a custom client for sync invocations.""" check_embedding_ctx_length: bool = True @@ -359,7 +359,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): def _tokenize( self, texts: list[str], chunk_size: int - ) -> tuple[Iterable[int], list[Union[list[int], str]], list[int]]: + ) -> tuple[Iterable[int], list[list[int] | str], list[int]]: """Take the input `texts` and `chunk_size` and return 3 iterables as a tuple. We have `batches`, where batches are sets of individual texts @@ -379,7 +379,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): indices: An iterable of the same length as `tokens` that maps each token-array to the index of the original text in `texts`. """ - tokens: list[Union[list[int], str]] = [] + tokens: list[list[int] | str] = [] indices: list[int] = [] model_name = self.tiktoken_model_name or self.model @@ -460,7 +460,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): texts: list[str], *, engine: str, - chunk_size: Optional[int] = None, + chunk_size: int | None = None, **kwargs: Any, ) -> list[list[float]]: """Generate length-safe embeddings for a list of texts. @@ -492,7 +492,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): embeddings = _process_batched_chunked_embeddings( len(texts), tokens, batched_embeddings, indices, self.skip_empty ) - _cached_empty_embedding: Optional[list[float]] = None + _cached_empty_embedding: list[float] | None = None def empty_embedding() -> list[float]: nonlocal _cached_empty_embedding @@ -512,7 +512,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): texts: list[str], *, engine: str, - chunk_size: Optional[int] = None, + chunk_size: int | None = None, **kwargs: Any, ) -> list[list[float]]: """Asynchronously generate length-safe embeddings for a list of texts. @@ -547,7 +547,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): embeddings = _process_batched_chunked_embeddings( len(texts), tokens, batched_embeddings, indices, self.skip_empty ) - _cached_empty_embedding: Optional[list[float]] = None + _cached_empty_embedding: list[float] | None = None async def empty_embedding() -> list[float]: nonlocal _cached_empty_embedding @@ -563,7 +563,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return [e if e is not None else await empty_embedding() for e in embeddings] def embed_documents( - self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any + self, texts: list[str], chunk_size: int | None = None, **kwargs: Any ) -> list[list[float]]: """Call out to OpenAI's embedding endpoint for embedding search docs. @@ -597,7 +597,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) async def aembed_documents( - self, texts: list[str], chunk_size: Optional[int] = None, **kwargs: Any + self, texts: list[str], chunk_size: int | None = None, **kwargs: Any ) -> list[list[float]]: """Call out to OpenAI's embedding endpoint async for embedding search docs. diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index bf4c09b68af..8fd9bc6d36c 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging from collections.abc import Awaitable, Callable, Mapping -from typing import Any, Optional, Union, cast +from typing import Any, cast import openai from langchain_core.language_models import LangSmithParams @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class AzureOpenAI(BaseOpenAI): """Azure-specific OpenAI large language models. - To use, you should have the ``openai`` python package installed, and the + To use, you should have the `openai` python package installed, and the environment variable ``OPENAI_API_KEY`` set with your API key. Any parameters that are valid to be passed to the openai.create call can be passed @@ -35,7 +35,7 @@ class AzureOpenAI(BaseOpenAI): """ - azure_endpoint: Optional[str] = Field( + azure_endpoint: str | None = Field( default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) ) """Your Azure endpoint, including the resource. @@ -44,7 +44,7 @@ class AzureOpenAI(BaseOpenAI): Example: ``'https://example-resource.azure.openai.com/'`` """ - deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment") + deployment_name: str | None = Field(default=None, alias="azure_deployment") """A model deployment. If given sets the base client URL to include `/deployments/{azure_deployment}`. @@ -53,7 +53,7 @@ class AzureOpenAI(BaseOpenAI): This means you won't be able to use non-deployment endpoints. """ - openai_api_version: Optional[str] = Field( + openai_api_version: str | None = Field( alias="api_version", default_factory=from_env("OPENAI_API_VERSION", default=None), ) @@ -61,13 +61,13 @@ class AzureOpenAI(BaseOpenAI): # Check OPENAI_KEY for backwards compatibility. # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using # other forms of azure credentials. - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env( ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None ), ) - azure_ad_token: Optional[SecretStr] = Field( + azure_ad_token: SecretStr | None = Field( default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) ) """Your Azure Active Directory token. @@ -76,18 +76,18 @@ class AzureOpenAI(BaseOpenAI): `For more, see this page .`__ """ - azure_ad_token_provider: Union[Callable[[], str], None] = None + azure_ad_token_provider: Callable[[], str] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every sync request. For async requests, will be invoked if ``azure_ad_async_token_provider`` is not provided. """ - azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None + azure_ad_async_token_provider: Callable[[], Awaitable[str]] | None = None """A function that returns an Azure Active Directory token. Will be invoked on every async request. """ - openai_api_type: Optional[str] = Field( + openai_api_type: str | None = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") ) """Legacy, for ``openai<1.0.0`` support.""" @@ -205,7 +205,7 @@ class AzureOpenAI(BaseOpenAI): return {**openai_params, **super()._invocation_params} def _get_ls_params( - self, stop: Optional[list[str]] = None, **kwargs: Any + self, stop: list[str] | None = None, **kwargs: Any ) -> LangSmithParams: """Get standard params for tracing.""" params = super()._get_ls_params(stop=stop, **kwargs) diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index cf192e98d68..233afeddf71 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import sys from collections.abc import AsyncIterator, Collection, Iterator, Mapping -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import openai import tiktoken @@ -188,16 +188,16 @@ class BaseOpenAI(BaseLLM): """Generates best_of completions server-side and returns the "best".""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: Optional[SecretStr] = Field( + openai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var ``OPENAI_API_KEY`` if not provided.""" - openai_api_base: Optional[str] = Field( + openai_api_base: str | None = Field( alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None) ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" - openai_organization: Optional[str] = Field( + openai_organization: str | None = Field( alias="organization", default_factory=from_env( ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None @@ -205,32 +205,32 @@ class BaseOpenAI(BaseLLM): ) """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided.""" # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = Field( + openai_proxy: str | None = Field( default_factory=from_env("OPENAI_PROXY", default=None) ) batch_size: int = 20 """Batch size to use when passing multiple documents to generate.""" - request_timeout: Union[float, tuple[float, float], Any, None] = Field( + request_timeout: float | tuple[float, float] | Any | None = Field( default=None, alias="timeout" ) """Timeout for requests to OpenAI completion API. Can be float, ``httpx.Timeout`` or None.""" - logit_bias: Optional[dict[str, float]] = None + logit_bias: dict[str, float] | None = None """Adjust the probability of specific tokens being generated.""" max_retries: int = 2 """Maximum number of retries to make when generating.""" - seed: Optional[int] = None + seed: int | None = None """Seed for generation""" - logprobs: Optional[int] = None + logprobs: int | None = None """Include the log probabilities on the logprobs most likely output tokens, as well the chosen tokens.""" streaming: bool = False """Whether to stream the results or not.""" - allowed_special: Union[Literal["all"], set[str]] = set() + allowed_special: Literal["all"] | set[str] = set() """Set of special tokens that are allowed。""" - disallowed_special: Union[Literal["all"], Collection[str]] = "all" + disallowed_special: Literal["all"] | Collection[str] = "all" """Set of special tokens that are not allowed。""" - tiktoken_model_name: Optional[str] = None + tiktoken_model_name: str | None = None """The model name to pass to tiktoken when using this class. Tiktoken is used to count the number of tokens in documents to constrain them to be under a certain limit. By default, when set to None, this will @@ -240,19 +240,19 @@ class BaseOpenAI(BaseLLM): when using one of the many model providers that expose an OpenAI-like API but with different models. In those cases, in order to avoid erroring when tiktoken is called, you can specify a model name to use here.""" - default_headers: Union[Mapping[str, str], None] = None - default_query: Union[Mapping[str, object], None] = None + default_headers: Mapping[str, str] | None = None + default_query: Mapping[str, object] | None = None # Configure a custom httpx client. See the # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. - http_client: Union[Any, None] = None + http_client: Any | None = None """Optional ``httpx.Client``. Only used for sync invocations. Must specify ``http_async_client`` as well if you'd like a custom client for async invocations. """ - http_async_client: Union[Any, None] = None + http_async_client: Any | None = None """Optional ``httpx.AsyncClient``. Only used for async invocations. Must specify ``http_client`` as well if you'd like a custom client for sync invocations.""" - extra_body: Optional[Mapping[str, Any]] = None + extra_body: Mapping[str, Any] | None = None """Optional additional JSON properties to include in the request parameters when making requests to OpenAI compatible APIs, such as vLLM.""" @@ -333,8 +333,8 @@ class BaseOpenAI(BaseLLM): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} @@ -360,8 +360,8 @@ class BaseOpenAI(BaseLLM): async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} @@ -389,8 +389,8 @@ class BaseOpenAI(BaseLLM): def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Call out to OpenAI's endpoint with k unique prompts. @@ -418,14 +418,14 @@ class BaseOpenAI(BaseLLM): # Get the token usage from the response. # Includes prompt, completion, and total tokens used. _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: msg = "Cannot stream results with multiple prompts." raise ValueError(msg) - generation: Optional[GenerationChunk] = None + generation: GenerationChunk | None = None for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): if generation is None: generation = chunk @@ -474,8 +474,8 @@ class BaseOpenAI(BaseLLM): async def _agenerate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Call out to OpenAI's endpoint async with k unique prompts.""" @@ -487,14 +487,14 @@ class BaseOpenAI(BaseLLM): # Get the token usage from the response. # Includes prompt, completion, and total tokens used. _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - system_fingerprint: Optional[str] = None + system_fingerprint: str | None = None for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: msg = "Cannot stream results with multiple prompts." raise ValueError(msg) - generation: Optional[GenerationChunk] = None + generation: GenerationChunk | None = None async for chunk in self._astream( _prompts[0], stop, run_manager, **kwargs ): @@ -534,7 +534,7 @@ class BaseOpenAI(BaseLLM): self, params: dict[str, Any], prompts: list[str], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, ) -> list[list[str]]: """Get the sub prompts for llm call.""" if stop is not None: @@ -556,7 +556,7 @@ class BaseOpenAI(BaseLLM): params: dict[str, Any], token_usage: dict[str, int], *, - system_fingerprint: Optional[str] = None, + system_fingerprint: str | None = None, ) -> LLMResult: """Create the LLMResult from the choices and prompts.""" generations = [] diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml index 50a9dfaba70..124d6e17b3f 100644 --- a/libs/partners/openai/pyproject.toml +++ b/libs/partners/openai/pyproject.toml @@ -76,8 +76,6 @@ ignore = [ "COM812", # Messes with the formatter "ISC001", # Messes with the formatter "PERF203", # Rarely useful - "UP007", # non-pep604-annotation-union - "UP045", # non-pep604-annotation-optional "SIM105", # Rarely useful "FIX", # TODOs "TD", # TODOs diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index b8586f03a30..da245c1c025 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -4,7 +4,7 @@ from __future__ import annotations import json import os -from typing import Any, Optional +from typing import Any import pytest from langchain_core.callbacks import CallbackManager @@ -164,7 +164,7 @@ async def test_async_chat_openai_streaming() -> None: @pytest.mark.scheduled def test_openai_streaming(llm: AzureChatOpenAI) -> None: """Test streaming tokens from OpenAI.""" - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk @@ -176,7 +176,7 @@ def test_openai_streaming(llm: AzureChatOpenAI) -> None: async def test_openai_astream(llm: AzureChatOpenAI) -> None: """Test streaming tokens from OpenAI.""" - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk @@ -239,7 +239,7 @@ def test_json_mode(llm: AzureChatOpenAI) -> None: assert json.loads(response.content) == {"a": 1} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream( "Return this as json: {'a': 1}", response_format={"type": "json_object"} ): @@ -257,7 +257,7 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None: assert json.loads(response.content) == {"a": 1} # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream( "Return this as json: {'a': 1}", response_format={"type": "json_object"} ): @@ -272,7 +272,7 @@ class Foo(BaseModel): def test_stream_response_format(llm: AzureChatOpenAI) -> None: - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks = [] for chunk in llm.stream("how are ya", response_format=Foo): chunks.append(chunk) @@ -287,7 +287,7 @@ def test_stream_response_format(llm: AzureChatOpenAI) -> None: async def test_astream_response_format(llm: AzureChatOpenAI) -> None: - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks = [] async for chunk in llm.astream("how are ya", response_format=Foo): chunks.append(chunk) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 1595599d713..3aad2f0988d 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -5,7 +5,7 @@ import json from collections.abc import AsyncIterator from pathlib import Path from textwrap import dedent -from typing import Any, Literal, Optional, cast +from typing import Any, Literal, cast import httpx import openai @@ -202,7 +202,7 @@ def test_stream() -> None: """Test streaming tokens from OpenAI.""" llm = ChatOpenAI(model="gpt-4.1-mini") - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk @@ -211,7 +211,7 @@ def test_stream() -> None: assert full.response_metadata.get("model_name") is not None # check token usage - aggregate: Optional[BaseMessageChunk] = None + aggregate: BaseMessageChunk | None = None chunks_with_token_counts = 0 chunks_with_response_metadata = 0 for chunk in llm.stream("Hello"): @@ -242,7 +242,7 @@ async def test_astream() -> None: """Test streaming tokens from OpenAI.""" async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None: - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks_with_token_counts = 0 chunks_with_response_metadata = 0 async for chunk in stream: @@ -357,7 +357,7 @@ async def test_async_response_metadata() -> None: def test_response_metadata_streaming() -> None: llm = ChatOpenAI() - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("I'm Pickle Rick", logprobs=True): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk @@ -370,7 +370,7 @@ def test_response_metadata_streaming() -> None: async def test_async_response_metadata_streaming() -> None: llm = ChatOpenAI() - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream("I'm Pickle Rick", logprobs=True): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk @@ -564,7 +564,7 @@ def test_openai_response_headers(use_responses_api: bool) -> None: assert "content-type" in headers # Stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in chat_openai.stream(query, max_tokens=MAX_TOKEN_COUNT): # type: ignore[call-arg] full = chunk if full is None else full + chunk assert isinstance(full, AIMessage) @@ -588,7 +588,7 @@ async def test_openai_response_headers_async(use_responses_api: bool) -> None: assert "content-type" in headers # Stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in chat_openai.astream(query, max_tokens=MAX_TOKEN_COUNT): # type: ignore[call-arg] full = chunk if full is None else full + chunk assert isinstance(full, AIMessage) @@ -672,7 +672,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None: class magic_function_notrequired_arg(BaseModel): # noqa: N801 """Applies a magic function to an input.""" - input: Optional[int] = Field(default=None) + input: int | None = Field(default=None) model = ChatOpenAI( model="gpt-5-nano", temperature=0, use_responses_api=use_responses_api @@ -694,7 +694,7 @@ def test_tool_calling_strict(use_responses_api: bool) -> None: model_with_invalid_tool_schema.invoke(query) # Test stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in model_with_tools.stream(query): full = chunk if full is None else full + chunk # type: ignore assert isinstance(full, AIMessage) @@ -960,7 +960,7 @@ class Foo(BaseModel): def test_stream_response_format() -> None: - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks = [] for chunk in ChatOpenAI(model="gpt-5-nano").stream( "how are ya", response_format=Foo @@ -977,7 +977,7 @@ def test_stream_response_format() -> None: async def test_astream_response_format() -> None: - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks = [] async for chunk in ChatOpenAI(model="gpt-5-nano").astream( "how are ya", response_format=Foo @@ -1048,7 +1048,7 @@ def test_structured_output_and_tools() -> None: assert isinstance(response.additional_kwargs["parsed"], ResponseFormat) # Test streaming tool calls - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream( "Generate a user name for Alice, black hair. Use the tool." ): diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index a573d960152..a7df4ef7ef7 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -2,7 +2,7 @@ import json import os -from typing import Annotated, Any, Literal, Optional, cast +from typing import Annotated, Any, Literal, cast import openai import pytest @@ -22,7 +22,7 @@ from langchain_openai import ChatOpenAI, custom_tool MODEL_NAME = "gpt-4o-mini" -def _check_response(response: Optional[BaseMessage]) -> None: +def _check_response(response: BaseMessage | None) -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, list) for block in response.content: @@ -69,7 +69,7 @@ def test_web_search(output_version: Literal["responses/v1", "v1"]) -> None: _check_response(first_response) # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream( "What was a positive news story from today?", tools=[{"type": "web_search_preview"}], @@ -123,7 +123,7 @@ async def test_web_search_async() -> None: assert response.response_metadata["status"] # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream( "What was a positive news story from today?", tools=[{"type": "web_search_preview"}], @@ -194,7 +194,7 @@ def test_parsed_pydantic_schema( assert parsed.response # Test stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("how are ya", response_format=Foo): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -212,7 +212,7 @@ async def test_parsed_pydantic_schema_async() -> None: assert parsed.response # Test stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream("how are ya", response_format=Foo): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -233,7 +233,7 @@ def test_parsed_dict_schema(schema: Any) -> None: assert isinstance(parsed["response"], str) # Test stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("how are ya", response_format=schema): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -280,7 +280,7 @@ async def test_parsed_dict_schema_async(schema: Any) -> None: assert isinstance(parsed["response"], str) # Test stream - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None async for chunk in llm.astream("how are ya", response_format=schema): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -325,7 +325,7 @@ def test_reasoning(output_version: Literal["v0", "responses/v1", "v1"]) -> None: llm = ChatOpenAI( model="o4-mini", reasoning={"effort": "low"}, output_version=output_version ) - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("Hello"): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -386,7 +386,7 @@ def test_file_search() -> None: response = llm.invoke([input_message], tools=[tool]) _check_response(response) - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message], tools=[tool]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -413,7 +413,7 @@ def test_stream_reasoning_summary( "role": "user", "content": "What was the third tallest buliding in the year 2000?", } - response_1: Optional[BaseMessageChunk] = None + response_1: BaseMessageChunk | None = None for chunk in llm.stream([message_1]): assert isinstance(chunk, AIMessageChunk) response_1 = chunk if response_1 is None else response_1 + chunk @@ -520,7 +520,7 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - [{"type": "code_interpreter", "container": container_id}] ) - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -624,7 +624,7 @@ def test_mcp_builtin_zdr() -> None: "spec (modelcontextprotocol/modelcontextprotocol) support?" ), } - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -678,7 +678,7 @@ def test_mcp_builtin_zdr_v1() -> None: "spec (modelcontextprotocol/modelcontextprotocol) support?" ), } - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm_with_tools.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -756,7 +756,7 @@ def test_image_generation_streaming( "type", } - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -801,7 +801,7 @@ def test_image_generation_streaming_v1() -> None: "status", } - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream("Draw a random short word in green font.", tools=[tool]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk @@ -1029,7 +1029,7 @@ def test_custom_tool(output_version: Literal["responses/v1", "v1"]) -> None: assert isinstance(response, AIMessage) # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in llm.stream([input_message]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index ff61301587d..8f9b16a0701 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -5,7 +5,7 @@ from __future__ import annotations import json from functools import partial from types import TracebackType -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -291,9 +291,9 @@ class MockAsyncContextManager: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: pass @@ -319,9 +319,9 @@ class MockSyncContextManager: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: pass @@ -370,7 +370,7 @@ async def test_glm4_astream(mock_glm4_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_glm4_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "async_client", mock_client): async for chunk in llm.astream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -395,7 +395,7 @@ def test_glm4_stream(mock_glm4_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_glm4_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "client", mock_client): for chunk in llm.stream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -450,7 +450,7 @@ async def test_deepseek_astream(mock_deepseek_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_deepseek_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "async_client", mock_client): async for chunk in llm.astream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -474,7 +474,7 @@ def test_deepseek_stream(mock_deepseek_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_deepseek_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "client", mock_client): for chunk in llm.stream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -519,7 +519,7 @@ async def test_openai_astream(mock_openai_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_openai_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "async_client", mock_client): async for chunk in llm.astream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -547,7 +547,7 @@ def test_openai_stream(mock_openai_completion: list) -> None: mock_client.create = mock_create usage_chunk = mock_openai_completion[-1] - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None with patch.object(llm, "client", mock_client): for chunk in llm.stream("你的名字叫什么?只回答名字"): assert isinstance(chunk, AIMessageChunk) @@ -873,7 +873,7 @@ class MakeASandwich(BaseModel): ], ) @pytest.mark.parametrize("strict", [True, False, None]) -def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None: +def test_bind_tools_tool_choice(tool_choice: Any, strict: bool | None) -> None: """Test passing in manually construct tool call message.""" llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm.bind_tools( @@ -888,10 +888,10 @@ def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> Non @pytest.mark.parametrize("include_raw", [True, False]) @pytest.mark.parametrize("strict", [True, False, None]) def test_with_structured_output( - schema: Union[type, dict[str, Any], None], + schema: type | dict[str, Any] | None, method: Literal["function_calling", "json_mode", "json_schema"], include_raw: bool, - strict: Optional[bool], + strict: bool | None, ) -> None: """Test passing in manually construct tool call message.""" if method == "json_mode": @@ -1088,7 +1088,7 @@ def test__convert_to_openai_response_format() -> None: @pytest.mark.parametrize("method", ["function_calling", "json_schema"]) @pytest.mark.parametrize("strict", [True, None]) def test_structured_output_strict( - method: Literal["function_calling", "json_schema"], strict: Optional[bool] + method: Literal["function_calling", "json_schema"], strict: bool | None ) -> None: """Test to verify structured output with strict=True.""" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py index 2b3f568c26c..75d31cf5e54 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -728,7 +728,7 @@ def test_responses_stream(output_version: str, expected_content: list[dict]) -> mock_client.responses.create = mock_create - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None chunks = [] with patch.object(llm, "root_client", mock_client): for chunk in llm.stream("test"): diff --git a/libs/partners/openai/tests/unit_tests/fake/callbacks.py b/libs/partners/openai/tests/unit_tests/fake/callbacks.py index c17cc1f6acf..6ab92453548 100644 --- a/libs/partners/openai/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/openai/tests/unit_tests/fake/callbacks.py @@ -3,7 +3,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, Optional, Union +from typing import Any from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler @@ -26,7 +26,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_chat_model_: bool = False # to allow for similar callback handlers that are not technically equal - fake_id: Union[str, None] = None + fake_id: str | None = None # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -201,7 +201,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: assert all(isinstance(m, BaseMessage) for m in chain(*messages)) diff --git a/libs/partners/perplexity/langchain_perplexity/chat_models.py b/libs/partners/perplexity/langchain_perplexity/chat_models.py index 08ffe2539a5..2893f538ba7 100644 --- a/libs/partners/perplexity/langchain_perplexity/chat_models.py +++ b/libs/partners/perplexity/langchain_perplexity/chat_models.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from collections.abc import Iterator, Mapping from operator import itemgetter -from typing import Any, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeAlias import openai from langchain_core.callbacks import CallbackManagerForLLMRun @@ -38,9 +38,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self -_BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type] -_DictOrPydantic = Union[dict, _BM] +_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel] +_DictOrPydantic: TypeAlias = dict | BaseModel logger = logging.getLogger(__name__) @@ -151,20 +150,18 @@ class ChatPerplexity(BaseChatModel): """What sampling temperature to use.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - pplx_api_key: Optional[SecretStr] = Field( + pplx_api_key: SecretStr | None = Field( default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key" ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" - request_timeout: Optional[Union[float, tuple[float, float]]] = Field( - None, alias="timeout" - ) + request_timeout: float | tuple[float, float] | None = Field(None, alias="timeout") """Timeout for requests to PerplexityChat completion API. Default is None.""" max_retries: int = 6 """Maximum number of retries to make when generating.""" streaming: bool = False """Whether to stream the results or not.""" - max_tokens: Optional[int] = None + max_tokens: int | None = None """Maximum number of tokens to generate.""" model_config = ConfigDict(populate_by_name=True) @@ -242,7 +239,7 @@ class ChatPerplexity(BaseChatModel): return message_dict def _create_message_dicts( - self, messages: list[BaseMessage], stop: Optional[list[str]] + self, messages: list[BaseMessage], stop: list[str] | None ) -> tuple[list[dict[str, Any]], dict[str, Any]]: params = dict(self._invocation_params) if stop is not None: @@ -284,8 +281,8 @@ class ChatPerplexity(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) @@ -298,7 +295,7 @@ class ChatPerplexity(BaseChatModel): messages=message_dicts, stream=True, **params ) first_chunk = True - prev_total_usage: Optional[UsageMetadata] = None + prev_total_usage: UsageMetadata | None = None added_model_name: bool = False for chunk in stream_resp: @@ -308,7 +305,7 @@ class ChatPerplexity(BaseChatModel): if total_usage := chunk.get("usage"): lc_total_usage = _create_usage_metadata(total_usage) if prev_total_usage: - usage_metadata: Optional[UsageMetadata] = subtract_usage( + usage_metadata: UsageMetadata | None = subtract_usage( lc_total_usage, prev_total_usage ) else: @@ -355,8 +352,8 @@ class ChatPerplexity(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.streaming: @@ -399,11 +396,11 @@ class ChatPerplexity(BaseChatModel): def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal["json_schema"] = "json_schema", include_raw: bool = False, - strict: Optional[bool] = None, + strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema for Preplexity. diff --git a/libs/partners/perplexity/pyproject.toml b/libs/partners/perplexity/pyproject.toml index 2086bf2403d..f64597fcc14 100644 --- a/libs/partners/perplexity/pyproject.toml +++ b/libs/partners/perplexity/pyproject.toml @@ -67,7 +67,6 @@ docstring-code-format = true [tool.ruff.lint] select = ["E", "F", "I", "T201", "UP", "S"] -ignore = [ "UP007", "UP045"] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/libs/partners/perplexity/tests/unit_tests/test_chat_models.py b/libs/partners/perplexity/tests/unit_tests/test_chat_models.py index 934262847f9..a05f80d6c3f 100644 --- a/libs/partners/perplexity/tests/unit_tests/test_chat_models.py +++ b/libs/partners/perplexity/tests/unit_tests/test_chat_models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, cast +from typing import Any, cast from unittest.mock import MagicMock from langchain_core.messages import AIMessageChunk, BaseMessage @@ -61,7 +61,7 @@ def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None: llm.client.chat.completions, "create", return_value=mock_stream ) stream = llm.stream("Hello langchain") - full: Optional[BaseMessage] = None + full: BaseMessage | None = None chunks_list = list(stream) # BaseChatModel.stream() adds an extra chunk after the final chunk from _stream assert len(chunks_list) == 4 @@ -124,7 +124,7 @@ def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture) llm.client.chat.completions, "create", return_value=mock_stream ) stream = llm.stream("Hello langchain") - full: Optional[BaseMessage] = None + full: BaseMessage | None = None chunks_list = list(stream) # BaseChatModel.stream() adds an extra chunk after the final chunk from _stream assert len(chunks_list) == 4 @@ -194,7 +194,7 @@ def test_perplexity_stream_includes_citations_and_related_questions( llm.client.chat.completions, "create", return_value=mock_stream ) stream = llm.stream("Hello langchain") - full: Optional[BaseMessage] = None + full: BaseMessage | None = None chunks_list = list(stream) # BaseChatModel.stream() adds an extra chunk after the final chunk from _stream assert len(chunks_list) == 4 @@ -258,7 +258,7 @@ def test_perplexity_stream_includes_citations_and_search_results( llm.client.chat.completions, "create", return_value=mock_stream ) stream = llm.stream("Hello langchain") - full: Optional[BaseMessage] = None + full: BaseMessage | None = None chunks_list = list(stream) # BaseChatModel.stream() adds an extra chunk after the final chunk from _stream assert len(chunks_list) == 4 diff --git a/libs/partners/prompty/langchain_prompty/core.py b/libs/partners/prompty/langchain_prompty/core.py index 6b9f2f40ed9..90d65795648 100644 --- a/libs/partners/prompty/langchain_prompty/core.py +++ b/libs/partners/prompty/langchain_prompty/core.py @@ -5,7 +5,7 @@ import json import os import re from pathlib import Path -from typing import Any, Generic, Literal, Optional, TypeVar, Union +from typing import Any, Generic, Literal, TypeVar import yaml from pydantic import BaseModel, ConfigDict, Field, FilePath @@ -24,7 +24,7 @@ class PropertySettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) type: Literal["string", "number", "array", "object", "boolean"] - default: Union[str, int, float, list, dict, bool, None] = Field(default=None) + default: str | int | float | list | dict | bool | None = Field(default=None) description: str = Field(default="") @@ -62,7 +62,7 @@ class Prompty(BaseModel): tags: list[str] = Field(default=[]) version: str = Field(default="") base: str = Field(default="") - basePrompty: Optional[Prompty] = Field(default=None) + basePrompty: Prompty | None = Field(default=None) # Model model: ModelSettings = Field(default_factory=ModelSettings) diff --git a/libs/partners/prompty/langchain_prompty/parsers.py b/libs/partners/prompty/langchain_prompty/parsers.py index 85e6cf0b45f..a29452fd7d9 100644 --- a/libs/partners/prompty/langchain_prompty/parsers.py +++ b/libs/partners/prompty/langchain_prompty/parsers.py @@ -1,6 +1,5 @@ import base64 import re -from typing import Union from langchain_core.messages import ( AIMessage, @@ -60,7 +59,7 @@ class PromptyChatParser(Invoker): "and .jpg / .jpeg are supported." ) - def parse_content(self, content: str) -> Union[str, list]: + def parse_content(self, content: str) -> str | list: """for parsing inline images""" # regular expression to parse markdown images image = r"(?P!\[[^\]]*\])\((?P.*?)(?=\"|\))\)" diff --git a/libs/partners/prompty/langchain_prompty/utils.py b/libs/partners/prompty/langchain_prompty/utils.py index 32fba7e743a..19c4dd837ec 100644 --- a/libs/partners/prompty/langchain_prompty/utils.py +++ b/libs/partners/prompty/langchain_prompty/utils.py @@ -1,6 +1,6 @@ import traceback from pathlib import Path -from typing import Any, Union +from typing import Any from .core import ( Frontmatter, @@ -166,7 +166,7 @@ def prepare( def run( prompt: Prompty, - content: Union[dict, list, str], + content: dict | list | str, configuration: dict[str, Any] = {}, parameters: dict[str, Any] = {}, raw: bool = False, @@ -218,7 +218,7 @@ def run( def execute( - prompt: Union[str, Prompty], + prompt: str | Prompty, configuration: dict[str, Any] = {}, parameters: dict[str, Any] = {}, inputs: dict[str, Any] = {}, diff --git a/libs/partners/prompty/pyproject.toml b/libs/partners/prompty/pyproject.toml index 252f1aac3c8..bf6f10c9265 100644 --- a/libs/partners/prompty/pyproject.toml +++ b/libs/partners/prompty/pyproject.toml @@ -65,7 +65,6 @@ docstring-code-format = true [tool.ruff.lint] select = ["E", "F", "I", "T201", "UP", "S"] -ignore = [ "UP007", "UP045" ] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/libs/partners/prompty/tests/unit_tests/fake_callback_handler.py b/libs/partners/prompty/tests/unit_tests/fake_callback_handler.py index bcff90185aa..bb520350b1a 100644 --- a/libs/partners/prompty/tests/unit_tests/fake_callback_handler.py +++ b/libs/partners/prompty/tests/unit_tests/fake_callback_handler.py @@ -1,7 +1,7 @@ """A fake callback handler for testing purposes.""" from itertools import chain -from typing import Any, Optional, Union +from typing import Any from uuid import UUID from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler @@ -23,7 +23,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_chat_model_: bool = False # to allow for similar callback handlers that are not technically equal - fake_id: Union[str, None] = None + fake_id: str | None = None # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -271,7 +271,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: assert all(isinstance(m, BaseMessage) for m in chain(*messages)) diff --git a/libs/partners/prompty/tests/unit_tests/fake_chat_model.py b/libs/partners/prompty/tests/unit_tests/fake_chat_model.py index c709f561f7e..62ca6c698c9 100644 --- a/libs/partners/prompty/tests/unit_tests/fake_chat_model.py +++ b/libs/partners/prompty/tests/unit_tests/fake_chat_model.py @@ -1,7 +1,7 @@ """Fake Chat Model wrapper for testing purposes.""" import json -from typing import Any, Optional +from typing import Any from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -18,8 +18,8 @@ class FakeEchoPromptChatModel(SimpleChatModel): def _call( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: return json.dumps([message.model_dump() for message in messages]) @@ -27,8 +27,8 @@ class FakeEchoPromptChatModel(SimpleChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: output_str = "fake response 2" diff --git a/libs/partners/prompty/tests/unit_tests/fake_output_parser.py b/libs/partners/prompty/tests/unit_tests/fake_output_parser.py index a3b1376602c..1e0f58b90a3 100644 --- a/libs/partners/prompty/tests/unit_tests/fake_output_parser.py +++ b/libs/partners/prompty/tests/unit_tests/fake_output_parser.py @@ -1,10 +1,8 @@ -from typing import Optional, Union - from langchain_classic.agents import AgentOutputParser from langchain_core.agents import AgentAction, AgentFinish -def extract_action_details(text: str) -> tuple[Optional[str], Optional[str]]: +def extract_action_details(text: str) -> tuple[str | None, str | None]: # Split the text into lines and strip whitespace lines = [line.strip() for line in text.strip().split("\n")] @@ -23,7 +21,7 @@ def extract_action_details(text: str) -> tuple[Optional[str], Optional[str]]: class FakeOutputParser(AgentOutputParser): - def parse(self, text: str) -> Union[AgentAction, AgentFinish]: + def parse(self, text: str) -> AgentAction | AgentFinish: action, input = extract_action_details(text) if action: diff --git a/libs/partners/qdrant/langchain_qdrant/_utils.py b/libs/partners/qdrant/langchain_qdrant/_utils.py index 0c77b939342..5ba46f3064a 100644 --- a/libs/partners/qdrant/langchain_qdrant/_utils.py +++ b/libs/partners/qdrant/langchain_qdrant/_utils.py @@ -1,8 +1,8 @@ -from typing import Union +from typing import TypeAlias import numpy as np -Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] +Matrix: TypeAlias = list[list[float]] | list[np.ndarray] | np.ndarray def maximal_marginal_relevance( @@ -44,8 +44,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: # noqa: N803 if len(X) == 0 or len(Y) == 0: return np.array([]) - x = np.array(X) - y = np.array(Y) + x: np.ndarray = np.array(X) + y: np.ndarray = np.array(Y) if x.shape[1] != y.shape[1]: msg = ( f"Number of columns in X and Y must be the same. X has shape {x.shape} " diff --git a/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py b/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py index 1b6a52eb486..234de095ecf 100644 --- a/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py +++ b/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from langchain_qdrant.sparse_embeddings import SparseEmbeddings, SparseVector @@ -15,10 +15,10 @@ class FastEmbedSparse(SparseEmbeddings): self, model_name: str = "Qdrant/bm25", batch_size: int = 256, - cache_dir: Optional[str] = None, - threads: Optional[int] = None, - providers: Optional[Sequence[Any]] = None, - parallel: Optional[int] = None, + cache_dir: str | None = None, + threads: int | None = None, + providers: Sequence[Any] | None = None, + parallel: int | None = None, **kwargs: Any, ) -> None: """Sparse encoder implementation using FastEmbed. diff --git a/libs/partners/qdrant/langchain_qdrant/qdrant.py b/libs/partners/qdrant/langchain_qdrant/qdrant.py index 013dbd747d5..42ba86977cd 100644 --- a/libs/partners/qdrant/langchain_qdrant/qdrant.py +++ b/libs/partners/qdrant/langchain_qdrant/qdrant.py @@ -8,8 +8,6 @@ from operator import itemgetter from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) import numpy as np @@ -218,13 +216,13 @@ class QdrantVectorStore(VectorStore): self, client: QdrantClient, collection_name: str, - embedding: Optional[Embeddings] = None, + embedding: Embeddings | None = None, retrieval_mode: RetrievalMode = RetrievalMode.DENSE, vector_name: str = VECTOR_NAME, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, distance: models.Distance = models.Distance.COSINE, - sparse_embedding: Optional[SparseEmbeddings] = None, + sparse_embedding: SparseEmbeddings | None = None, sparse_vector_name: str = SPARSE_VECTOR_NAME, validate_embeddings: bool = True, # noqa: FBT001, FBT002 validate_collection_config: bool = True, # noqa: FBT001, FBT002 @@ -278,7 +276,7 @@ class QdrantVectorStore(VectorStore): return self._client @property - def embeddings(self) -> Optional[Embeddings]: + def embeddings(self) -> Embeddings | None: """Get the dense embeddings instance that is being used. Returns: @@ -347,31 +345,31 @@ class QdrantVectorStore(VectorStore): def from_texts( cls: type[QdrantVectorStore], texts: list[str], - embedding: Optional[Embeddings] = None, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str | int]] = None, - collection_name: Optional[str] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + embedding: Embeddings | None = None, + metadatas: list[dict] | None = None, + ids: Sequence[str | int] | None = None, + collection_name: str | None = None, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, distance: models.Distance = models.Distance.COSINE, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, vector_name: str = VECTOR_NAME, retrieval_mode: RetrievalMode = RetrievalMode.DENSE, - sparse_embedding: Optional[SparseEmbeddings] = None, + sparse_embedding: SparseEmbeddings | None = None, sparse_vector_name: str = SPARSE_VECTOR_NAME, - collection_create_options: Optional[dict[str, Any]] = None, - vector_params: Optional[dict[str, Any]] = None, - sparse_vector_params: Optional[dict[str, Any]] = None, + collection_create_options: dict[str, Any] | None = None, + vector_params: dict[str, Any] | None = None, + sparse_vector_params: dict[str, Any] | None = None, batch_size: int = 64, force_recreate: bool = False, # noqa: FBT001, FBT002 validate_embeddings: bool = True, # noqa: FBT001, FBT002 @@ -442,25 +440,25 @@ class QdrantVectorStore(VectorStore): def from_existing_collection( cls: type[QdrantVectorStore], collection_name: str, - embedding: Optional[Embeddings] = None, + embedding: Embeddings | None = None, retrieval_mode: RetrievalMode = RetrievalMode.DENSE, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, distance: models.Distance = models.Distance.COSINE, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, vector_name: str = VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME, - sparse_embedding: Optional[SparseEmbeddings] = None, + sparse_embedding: SparseEmbeddings | None = None, validate_embeddings: bool = True, # noqa: FBT001, FBT002 validate_collection_config: bool = True, # noqa: FBT001, FBT002 **kwargs: Any, @@ -503,8 +501,8 @@ class QdrantVectorStore(VectorStore): def add_texts( # type: ignore[override] self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str | int]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str | int] | None = None, batch_size: int = 64, **kwargs: Any, ) -> list[str | int]: @@ -529,12 +527,12 @@ class QdrantVectorStore(VectorStore): self, query: str, k: int = 4, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, - hybrid_fusion: Optional[models.FusionQuery] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, + hybrid_fusion: models.FusionQuery | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to query. @@ -560,12 +558,12 @@ class QdrantVectorStore(VectorStore): self, query: str, k: int = 4, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, - hybrid_fusion: Optional[models.FusionQuery] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, + hybrid_fusion: models.FusionQuery | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to query. @@ -654,11 +652,11 @@ class QdrantVectorStore(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to embedding vector. @@ -708,11 +706,11 @@ class QdrantVectorStore(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. @@ -739,10 +737,10 @@ class QdrantVectorStore(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance with dense vectors. @@ -782,10 +780,10 @@ class QdrantVectorStore(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance with dense vectors. @@ -816,10 +814,10 @@ class QdrantVectorStore(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[models.Filter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: models.Filter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs selected using the maximal marginal relevance. @@ -870,9 +868,9 @@ class QdrantVectorStore(VectorStore): def delete( # type: ignore[override] self, - ids: Optional[list[str | int]] = None, + ids: list[str | int] | None = None, **kwargs: Any, - ) -> Optional[bool]: + ) -> bool | None: """Delete documents by their ids. Args: @@ -905,20 +903,20 @@ class QdrantVectorStore(VectorStore): @classmethod def construct_instance( cls: type[QdrantVectorStore], - embedding: Optional[Embeddings] = None, + embedding: Embeddings | None = None, retrieval_mode: RetrievalMode = RetrievalMode.DENSE, - sparse_embedding: Optional[SparseEmbeddings] = None, - client_options: Optional[dict[str, Any]] = None, - collection_name: Optional[str] = None, + sparse_embedding: SparseEmbeddings | None = None, + client_options: dict[str, Any] | None = None, + collection_name: str | None = None, distance: models.Distance = models.Distance.COSINE, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, vector_name: str = VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME, force_recreate: bool = False, # noqa: FBT001, FBT002 - collection_create_options: Optional[dict[str, Any]] = None, - vector_params: Optional[dict[str, Any]] = None, - sparse_vector_params: Optional[dict[str, Any]] = None, + collection_create_options: dict[str, Any] | None = None, + vector_params: dict[str, Any] | None = None, + sparse_vector_params: dict[str, Any] | None = None, validate_embeddings: bool = True, # noqa: FBT001, FBT002 validate_collection_config: bool = True, # noqa: FBT001, FBT002 ) -> QdrantVectorStore: @@ -1053,8 +1051,8 @@ class QdrantVectorStore(VectorStore): def _generate_batches( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str | int]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str | int] | None = None, batch_size: int = 64, ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]: texts_iterator = iter(texts) @@ -1088,7 +1086,7 @@ class QdrantVectorStore(VectorStore): @staticmethod def _build_payloads( texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, ) -> list[dict]: @@ -1170,7 +1168,7 @@ class QdrantVectorStore(VectorStore): vector_name: str, sparse_vector_name: str, distance: models.Distance, - embedding: Optional[Embeddings], + embedding: Embeddings | None, ) -> None: if retrieval_mode == RetrievalMode.DENSE: cls._validate_collection_for_dense( @@ -1197,7 +1195,7 @@ class QdrantVectorStore(VectorStore): collection_name: str, vector_name: str, distance: models.Distance, - dense_embeddings: Union[Embeddings, list[float], None], + dense_embeddings: Embeddings | list[float] | None, ) -> None: collection_info = client.get_collection(collection_name=collection_name) vector_config = collection_info.config.params.vectors @@ -1289,8 +1287,8 @@ class QdrantVectorStore(VectorStore): def _validate_embeddings( cls: type[QdrantVectorStore], retrieval_mode: RetrievalMode, - embedding: Optional[Embeddings], - sparse_embedding: Optional[SparseEmbeddings], + embedding: Embeddings | None, + sparse_embedding: SparseEmbeddings | None, ) -> None: if retrieval_mode == RetrievalMode.DENSE and embedding is None: msg = "'embedding' cannot be None when retrieval mode is 'dense'" diff --git a/libs/partners/qdrant/langchain_qdrant/vectorstores.py b/libs/partners/qdrant/langchain_qdrant/vectorstores.py index 9ad20acff84..efeaee74337 100644 --- a/libs/partners/qdrant/langchain_qdrant/vectorstores.py +++ b/libs/partners/qdrant/langchain_qdrant/vectorstores.py @@ -7,7 +7,7 @@ import warnings from collections.abc import Callable from itertools import islice from operator import itemgetter -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np from langchain_core._api.deprecation import deprecated @@ -24,8 +24,8 @@ from langchain_qdrant._utils import maximal_marginal_relevance if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable, Sequence - DictFilter = dict[str, Union[str, int, bool, dict, list]] - MetadataFilter = Union[DictFilter, models.Filter] + DictFilter = dict[str, str | int | bool | dict | list] + MetadataFilter = DictFilter | models.Filter class QdrantException(Exception): # noqa: N818 @@ -74,19 +74,19 @@ class Qdrant(VectorStore): CONTENT_KEY: str = "page_content" METADATA_KEY: str = "metadata" - VECTOR_NAME: Optional[str] = None + VECTOR_NAME: str | None = None def __init__( self, client: Any, collection_name: str, - embeddings: Optional[Embeddings] = None, + embeddings: Embeddings | None = None, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, distance_strategy: str = "COSINE", - vector_name: Optional[str] = VECTOR_NAME, - async_client: Optional[Any] = None, - embedding_function: Optional[Callable] = None, # deprecated + vector_name: str | None = VECTOR_NAME, + async_client: Any | None = None, + embedding_function: Callable | None = None, # deprecated ) -> None: """Initialize with necessary components.""" if not isinstance(client, QdrantClient): @@ -117,7 +117,7 @@ class Qdrant(VectorStore): self._embeddings = embeddings self._embeddings_function = embedding_function self.client: QdrantClient = client - self.async_client: Optional[AsyncQdrantClient] = async_client + self.async_client: AsyncQdrantClient | None = async_client self.collection_name = collection_name self.content_payload_key = content_payload_key or self.CONTENT_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY @@ -142,14 +142,14 @@ class Qdrant(VectorStore): self.distance_strategy = distance_strategy.upper() @property - def embeddings(self) -> Optional[Embeddings]: + def embeddings(self) -> Embeddings | None: return self._embeddings def add_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, **kwargs: Any, ) -> list[str]: @@ -185,8 +185,8 @@ class Qdrant(VectorStore): async def aadd_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, **kwargs: Any, ) -> list[str]: @@ -228,11 +228,11 @@ class Qdrant(VectorStore): self, query: str, k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to query. @@ -287,7 +287,7 @@ class Qdrant(VectorStore): self, query: str, k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 + filter: MetadataFilter | None = None, # noqa: A002 **kwargs: Any, ) -> list[Document]: """Return docs most similar to query. @@ -309,11 +309,11 @@ class Qdrant(VectorStore): self, query: str, k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to query. @@ -367,11 +367,11 @@ class Qdrant(VectorStore): self, query: str, k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to query. @@ -426,11 +426,11 @@ class Qdrant(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. @@ -485,11 +485,11 @@ class Qdrant(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs most similar to embedding vector. @@ -544,11 +544,11 @@ class Qdrant(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to embedding vector. @@ -633,11 +633,11 @@ class Qdrant(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs most similar to embedding vector. @@ -729,10 +729,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -795,10 +795,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -861,10 +861,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -926,10 +926,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance. @@ -993,10 +993,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs selected using the maximal marginal relevance. @@ -1085,10 +1085,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[MetadataFilter] = None, # noqa: A002 - search_params: Optional[models.SearchParams] = None, - score_threshold: Optional[float] = None, - consistency: Optional[models.ReadConsistency] = None, + filter: MetadataFilter | None = None, # noqa: A002 + search_params: models.SearchParams | None = None, + score_threshold: float | None = None, + consistency: models.ReadConsistency | None = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs selected using the maximal marginal relevance. @@ -1159,7 +1159,7 @@ class Qdrant(VectorStore): for i in mmr_selected ] - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: + def delete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None: """Delete by vector ID or other criteria. Args: @@ -1177,9 +1177,7 @@ class Qdrant(VectorStore): return result.status == models.UpdateStatus.COMPLETED @sync_call_fallback - async def adelete( - self, ids: Optional[list[str]] = None, **kwargs: Any - ) -> Optional[bool]: + async def adelete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None: """Delete by vector ID or other criteria. Args: @@ -1208,35 +1206,35 @@ class Qdrant(VectorStore): cls: type[Qdrant], texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, + collection_name: str | None = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, + vector_name: str | None = VECTOR_NAME, batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, # noqa: FBT001 - hnsw_config: Optional[models.HnswConfigDiff] = None, - optimizers_config: Optional[models.OptimizersConfigDiff] = None, - wal_config: Optional[models.WalConfigDiff] = None, - quantization_config: Optional[models.QuantizationConfig] = None, - init_from: Optional[models.InitFrom] = None, - on_disk: Optional[bool] = None, # noqa: FBT001 + shard_number: int | None = None, + replication_factor: int | None = None, + write_consistency_factor: int | None = None, + on_disk_payload: bool | None = None, # noqa: FBT001 + hnsw_config: models.HnswConfigDiff | None = None, + optimizers_config: models.OptimizersConfigDiff | None = None, + wal_config: models.WalConfigDiff | None = None, + quantization_config: models.QuantizationConfig | None = None, + init_from: models.InitFrom | None = None, + on_disk: bool | None = None, # noqa: FBT001 force_recreate: bool = False, # noqa: FBT001, FBT002 **kwargs: Any, ) -> Qdrant: @@ -1388,22 +1386,22 @@ class Qdrant(VectorStore): def from_existing_collection( cls: type[Qdrant], embedding: Embeddings, - path: Optional[str] = None, - collection_name: Optional[str] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + path: str | None = None, + collection_name: str | None = None, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, distance_strategy: str = "COSINE", - vector_name: Optional[str] = VECTOR_NAME, + vector_name: str | None = VECTOR_NAME, **kwargs: Any, ) -> Qdrant: """Get instance of an existing Qdrant collection. @@ -1446,35 +1444,35 @@ class Qdrant(VectorStore): cls: type[Qdrant], texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, + collection_name: str | None = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, + vector_name: str | None = VECTOR_NAME, batch_size: int = 64, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, # noqa: FBT001 - hnsw_config: Optional[models.HnswConfigDiff] = None, - optimizers_config: Optional[models.OptimizersConfigDiff] = None, - wal_config: Optional[models.WalConfigDiff] = None, - quantization_config: Optional[models.QuantizationConfig] = None, - init_from: Optional[models.InitFrom] = None, - on_disk: Optional[bool] = None, # noqa: FBT001 + shard_number: int | None = None, + replication_factor: int | None = None, + write_consistency_factor: int | None = None, + on_disk_payload: bool | None = None, # noqa: FBT001 + hnsw_config: models.HnswConfigDiff | None = None, + optimizers_config: models.OptimizersConfigDiff | None = None, + wal_config: models.WalConfigDiff | None = None, + quantization_config: models.QuantizationConfig | None = None, + init_from: models.InitFrom | None = None, + on_disk: bool | None = None, # noqa: FBT001 force_recreate: bool = False, # noqa: FBT001, FBT002 **kwargs: Any, ) -> Qdrant: @@ -1631,32 +1629,32 @@ class Qdrant(VectorStore): cls: type[Qdrant], texts: list[str], embedding: Embeddings, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, + collection_name: str | None = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, # noqa: FBT001 - hnsw_config: Optional[models.HnswConfigDiff] = None, - optimizers_config: Optional[models.OptimizersConfigDiff] = None, - wal_config: Optional[models.WalConfigDiff] = None, - quantization_config: Optional[models.QuantizationConfig] = None, - init_from: Optional[models.InitFrom] = None, - on_disk: Optional[bool] = None, # noqa: FBT001 + vector_name: str | None = VECTOR_NAME, + shard_number: int | None = None, + replication_factor: int | None = None, + write_consistency_factor: int | None = None, + on_disk_payload: bool | None = None, # noqa: FBT001 + hnsw_config: models.HnswConfigDiff | None = None, + optimizers_config: models.OptimizersConfigDiff | None = None, + wal_config: models.WalConfigDiff | None = None, + quantization_config: models.QuantizationConfig | None = None, + init_from: models.InitFrom | None = None, + on_disk: bool | None = None, # noqa: FBT001 force_recreate: bool = False, # noqa: FBT001, FBT002 **kwargs: Any, ) -> Qdrant: @@ -1796,32 +1794,32 @@ class Qdrant(VectorStore): cls: type[Qdrant], texts: list[str], embedding: Embeddings, - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, - collection_name: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, + collection_name: str | None = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, - vector_name: Optional[str] = VECTOR_NAME, - shard_number: Optional[int] = None, - replication_factor: Optional[int] = None, - write_consistency_factor: Optional[int] = None, - on_disk_payload: Optional[bool] = None, # noqa: FBT001 - hnsw_config: Optional[models.HnswConfigDiff] = None, - optimizers_config: Optional[models.OptimizersConfigDiff] = None, - wal_config: Optional[models.WalConfigDiff] = None, - quantization_config: Optional[models.QuantizationConfig] = None, - init_from: Optional[models.InitFrom] = None, - on_disk: Optional[bool] = None, # noqa: FBT001 + vector_name: str | None = VECTOR_NAME, + shard_number: int | None = None, + replication_factor: int | None = None, + write_consistency_factor: int | None = None, + on_disk_payload: bool | None = None, # noqa: FBT001 + hnsw_config: models.HnswConfigDiff | None = None, + optimizers_config: models.OptimizersConfigDiff | None = None, + wal_config: models.WalConfigDiff | None = None, + quantization_config: models.QuantizationConfig | None = None, + init_from: models.InitFrom | None = None, + on_disk: bool | None = None, # noqa: FBT001 force_recreate: bool = False, # noqa: FBT001, FBT002 **kwargs: Any, ) -> Qdrant: @@ -2034,7 +2032,7 @@ class Qdrant(VectorStore): def _build_payloads( cls, texts: Iterable[str], - metadatas: Optional[list[dict]], + metadatas: list[dict] | None, content_payload_key: str, metadata_payload_key: str, ) -> list[dict]: @@ -2095,8 +2093,8 @@ class Qdrant(VectorStore): return out def _qdrant_filter_from_dict( - self, filter_: Optional[DictFilter] - ) -> Optional[models.Filter]: + self, filter_: DictFilter | None + ) -> models.Filter | None: if not filter_: return None @@ -2211,8 +2209,8 @@ class Qdrant(VectorStore): def _generate_rest_batches( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, ) -> Generator[tuple[list[str], list[models.PointStruct]], None, None]: texts_iterator = iter(texts) @@ -2252,8 +2250,8 @@ class Qdrant(VectorStore): async def _agenerate_rest_batches( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, + metadatas: list[dict] | None = None, + ids: Sequence[str] | None = None, batch_size: int = 64, ) -> AsyncGenerator[tuple[list[str], list[models.PointStruct]], None]: texts_iterator = iter(texts) @@ -2292,19 +2290,19 @@ class Qdrant(VectorStore): @staticmethod def _generate_clients( - location: Optional[str] = None, - url: Optional[str] = None, - port: Optional[int] = 6333, + location: str | None = None, + url: str | None = None, + port: int | None = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, # noqa: FBT001, FBT002 - https: Optional[bool] = None, # noqa: FBT001 - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[int] = None, - host: Optional[str] = None, - path: Optional[str] = None, + https: bool | None = None, # noqa: FBT001 + api_key: str | None = None, + prefix: str | None = None, + timeout: int | None = None, + host: str | None = None, + path: str | None = None, **kwargs: Any, - ) -> tuple[QdrantClient, Optional[AsyncQdrantClient]]: + ) -> tuple[QdrantClient, AsyncQdrantClient | None]: if api_key is None: api_key = os.getenv("QDRANT_API_KEY") diff --git a/libs/partners/qdrant/pyproject.toml b/libs/partners/qdrant/pyproject.toml index dc03421f413..3985fcc99cb 100644 --- a/libs/partners/qdrant/pyproject.toml +++ b/libs/partners/qdrant/pyproject.toml @@ -69,8 +69,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "PLR0913", # Function has too many arguments "C901", # Complex functions "TC003", diff --git a/libs/partners/qdrant/tests/integration_tests/async_api/test_add_texts.py b/libs/partners/qdrant/tests/integration_tests/async_api/test_add_texts.py index 704d27e531a..958a761283c 100644 --- a/libs/partners/qdrant/tests/integration_tests/async_api/test_add_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/async_api/test_add_texts.py @@ -2,7 +2,6 @@ from __future__ import annotations import os import uuid -from typing import Optional import pytest # type: ignore[import-not-found] @@ -34,7 +33,7 @@ async def test_qdrant_aadd_texts_returns_all_ids( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_aadd_texts_stores_duplicated_texts( - vector_name: Optional[str], qdrant_location: str + vector_name: str | None, qdrant_location: str ) -> None: """Test end to end Qdrant.aadd_texts stores duplicated texts separately.""" from qdrant_client import QdrantClient diff --git a/libs/partners/qdrant/tests/integration_tests/async_api/test_from_texts.py b/libs/partners/qdrant/tests/integration_tests/async_api/test_from_texts.py index e338a835a37..88fe2a8c1c7 100644 --- a/libs/partners/qdrant/tests/integration_tests/async_api/test_from_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/async_api/test_from_texts.py @@ -2,7 +2,6 @@ from __future__ import annotations import os import uuid -from typing import Optional import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -38,7 +37,7 @@ async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) - @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_from_texts_stores_ids( - batch_size: int, vector_name: Optional[str], qdrant_location: str + batch_size: int, vector_name: str | None, qdrant_location: str ) -> None: """Test end to end Qdrant.afrom_texts stores provided ids.""" collection_name = uuid.uuid4().hex @@ -90,7 +89,7 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors( @pytest.mark.parametrize("location", qdrant_locations(use_in_memory=False)) @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) async def test_qdrant_from_texts_reuses_same_collection( - location: str, vector_name: Optional[str] + location: str, vector_name: str | None ) -> None: """Test if Qdrant.afrom_texts reuses the same collection.""" collection_name = uuid.uuid4().hex @@ -120,7 +119,7 @@ async def test_qdrant_from_texts_reuses_same_collection( @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) async def test_qdrant_from_texts_raises_error_on_different_dimensionality( location: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test if Qdrant.afrom_texts raises an exception if dimensionality does not match. @@ -156,8 +155,8 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality( ) async def test_qdrant_from_texts_raises_error_on_different_vector_name( location: str, - first_vector_name: Optional[str], - second_vector_name: Optional[str], + first_vector_name: str | None, + second_vector_name: str | None, ) -> None: """Test if Qdrant.afrom_texts raises an exception if vector name does not match.""" collection_name = uuid.uuid4().hex @@ -209,7 +208,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance( @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) async def test_qdrant_from_texts_recreates_collection_on_force_recreate( location: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test if Qdrant.afrom_texts recreates the collection even if config mismatches.""" from qdrant_client import QdrantClient diff --git a/libs/partners/qdrant/tests/integration_tests/async_api/test_max_marginal_relevance.py b/libs/partners/qdrant/tests/integration_tests/async_api/test_max_marginal_relevance.py index 896d5e2b77a..afc7f5223ac 100644 --- a/libs/partners/qdrant/tests/integration_tests/async_api/test_max_marginal_relevance.py +++ b/libs/partners/qdrant/tests/integration_tests/async_api/test_max_marginal_relevance.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -24,7 +22,7 @@ async def test_qdrant_max_marginal_relevance_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and MRR search.""" diff --git a/libs/partners/qdrant/tests/integration_tests/async_api/test_similarity_search.py b/libs/partners/qdrant/tests/integration_tests/async_api/test_similarity_search.py index ba1d41361d3..8353999a1e9 100644 --- a/libs/partners/qdrant/tests/integration_tests/async_api/test_similarity_search.py +++ b/libs/partners/qdrant/tests/integration_tests/async_api/test_similarity_search.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import numpy as np import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -23,7 +21,7 @@ async def test_qdrant_similarity_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -50,7 +48,7 @@ async def test_qdrant_similarity_search_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -78,7 +76,7 @@ async def test_qdrant_similarity_search_with_score_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -104,7 +102,7 @@ async def test_qdrant_similarity_search_with_score_by_vector( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_filters( - batch_size: int, vector_name: Optional[str], qdrant_location: str + batch_size: int, vector_name: str | None, qdrant_location: str ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -138,7 +136,7 @@ async def test_qdrant_similarity_search_filters( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_relevance_score_no_threshold( - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -166,7 +164,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_relevance_score_with_threshold( - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -195,7 +193,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_similarity_search_with_relevance_score_with_threshold_and_filter( - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" @@ -230,7 +228,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_filters_with_qdrant_filters( - vector_name: Optional[str], + vector_name: str | None, qdrant_location: str, ) -> None: """Test end to end construction and search.""" diff --git a/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_add_texts.py b/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_add_texts.py index 25d0554b42e..61556718d71 100644 --- a/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_add_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_add_texts.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from typing import Union import pytest from langchain_core.documents import Document @@ -119,7 +118,7 @@ def test_qdrant_add_texts_stores_ids( batch_size: int, ) -> None: """Test end to end Qdrant.add_texts stores provided ids.""" - ids: list[Union[str, int]] = [ + ids: list[str | int] = [ "fa38d572-4c31-4579-aedc-1960d79df6df", 432, 432145435, diff --git a/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_from_texts.py b/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_from_texts.py index 2a11f7c0618..3fae4976628 100644 --- a/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_from_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/qdrant_vector_store/test_from_texts.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from typing import Union import pytest from langchain_core.documents import Document @@ -51,7 +50,7 @@ def test_qdrant_from_texts_stores_ids( ) -> None: """Test end to end Qdrant.from_texts stores provided ids.""" collection_name = uuid.uuid4().hex - ids: list[Union[str, int]] = [ + ids: list[str | int] = [ "fa38d572-4c31-4579-aedc-1960d79df6df", 786, ] diff --git a/libs/partners/qdrant/tests/integration_tests/test_add_texts.py b/libs/partners/qdrant/tests/integration_tests/test_add_texts.py index 5ed5ca3722d..d0b38be3db9 100644 --- a/libs/partners/qdrant/tests/integration_tests/test_add_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/test_add_texts.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from typing import Optional import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -16,7 +15,7 @@ from tests.integration_tests.common import ( @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_add_documents_extends_existing_collection( - batch_size: int, vector_name: Optional[str] + batch_size: int, vector_name: str | None ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -55,7 +54,7 @@ def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None: @pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) -> None: +def test_qdrant_add_texts_stores_duplicated_texts(vector_name: str | None) -> None: """Test end to end Qdrant.add_texts stores duplicated texts separately.""" from qdrant_client import QdrantClient from qdrant_client.http import models as rest diff --git a/libs/partners/qdrant/tests/integration_tests/test_embedding_interface.py b/libs/partners/qdrant/tests/integration_tests/test_embedding_interface.py index 60cf2dd9747..9a1e15eeec0 100644 --- a/libs/partners/qdrant/tests/integration_tests/test_embedding_interface.py +++ b/libs/partners/qdrant/tests/integration_tests/test_embedding_interface.py @@ -2,7 +2,7 @@ from __future__ import annotations import uuid from collections.abc import Callable -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest # type: ignore[import-not-found] @@ -22,7 +22,7 @@ if TYPE_CHECKING: ], ) def test_qdrant_embedding_interface( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] + embeddings: Embeddings | None, embedding_function: Callable | None ) -> None: """Test Qdrant may accept different types for embeddings.""" from qdrant_client import QdrantClient @@ -46,7 +46,7 @@ def test_qdrant_embedding_interface( ], ) def test_qdrant_embedding_interface_raises_value_error( - embeddings: Optional[Embeddings], embedding_function: Optional[Callable] + embeddings: Embeddings | None, embedding_function: Callable | None ) -> None: """Test Qdrant requires only one method for embeddings.""" from qdrant_client import QdrantClient diff --git a/libs/partners/qdrant/tests/integration_tests/test_from_texts.py b/libs/partners/qdrant/tests/integration_tests/test_from_texts.py index 1b61029aedc..cbd882fd636 100644 --- a/libs/partners/qdrant/tests/integration_tests/test_from_texts.py +++ b/libs/partners/qdrant/tests/integration_tests/test_from_texts.py @@ -2,7 +2,6 @@ from __future__ import annotations import tempfile import uuid -from typing import Optional import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -37,9 +36,7 @@ def test_qdrant_from_texts_stores_duplicated_texts() -> None: @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) -def test_qdrant_from_texts_stores_ids( - batch_size: int, vector_name: Optional[str] -) -> None: +def test_qdrant_from_texts_stores_ids(batch_size: int, vector_name: str | None) -> None: """Test end to end Qdrant.from_texts stores provided ids.""" from qdrant_client import QdrantClient @@ -91,7 +88,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str) @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) -def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None: +def test_qdrant_from_texts_reuses_same_collection(vector_name: str | None) -> None: """Test if Qdrant.from_texts reuses the same collection.""" from qdrant_client import QdrantClient @@ -122,7 +119,7 @@ def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) def test_qdrant_from_texts_raises_error_on_different_dimensionality( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test if Qdrant.from_texts raises an exception if dimensionality doesn't match.""" collection_name = uuid.uuid4().hex @@ -155,8 +152,8 @@ def test_qdrant_from_texts_raises_error_on_different_dimensionality( ], ) def test_qdrant_from_texts_raises_error_on_different_vector_name( - first_vector_name: Optional[str], - second_vector_name: Optional[str], + first_vector_name: str | None, + second_vector_name: str | None, ) -> None: """Test if Qdrant.from_texts raises an exception if vector name does not match.""" collection_name = uuid.uuid4().hex @@ -211,7 +208,7 @@ def test_qdrant_from_texts_raises_error_on_different_distance() -> None: @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) def test_qdrant_from_texts_recreates_collection_on_force_recreate( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test if Qdrant.from_texts recreates the collection even if config mismatches.""" from qdrant_client import QdrantClient diff --git a/libs/partners/qdrant/tests/integration_tests/test_max_marginal_relevance.py b/libs/partners/qdrant/tests/integration_tests/test_max_marginal_relevance.py index 5c76f915095..c69d8867d2b 100644 --- a/libs/partners/qdrant/tests/integration_tests/test_max_marginal_relevance.py +++ b/libs/partners/qdrant/tests/integration_tests/test_max_marginal_relevance.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import pytest # type: ignore[import-not-found] from langchain_core.documents import Document from qdrant_client import models @@ -21,7 +19,7 @@ def test_qdrant_max_marginal_relevance_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and MRR search.""" filter_ = models.Filter( diff --git a/libs/partners/qdrant/tests/integration_tests/test_similarity_search.py b/libs/partners/qdrant/tests/integration_tests/test_similarity_search.py index 8471cafd0b9..bcd69898f4c 100644 --- a/libs/partners/qdrant/tests/integration_tests/test_similarity_search.py +++ b/libs/partners/qdrant/tests/integration_tests/test_similarity_search.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import numpy as np import pytest # type: ignore[import-not-found] from langchain_core.documents import Document @@ -22,7 +20,7 @@ def test_qdrant_similarity_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -47,7 +45,7 @@ def test_qdrant_similarity_search_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -73,7 +71,7 @@ def test_qdrant_similarity_search_with_score_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -97,7 +95,7 @@ def test_qdrant_similarity_search_with_score_by_vector( @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_similarity_search_filters( - batch_size: int, vector_name: Optional[str] + batch_size: int, vector_name: str | None ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -131,7 +129,7 @@ def test_qdrant_similarity_search_filters( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_similarity_search_with_relevance_score_no_threshold( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -157,7 +155,7 @@ def test_qdrant_similarity_search_with_relevance_score_no_threshold( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_similarity_search_with_relevance_score_with_threshold( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -182,7 +180,7 @@ def test_qdrant_similarity_search_with_relevance_score_with_threshold( @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_similarity_search_with_relevance_score_with_threshold_and_filter( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -213,7 +211,7 @@ def test_qdrant_similarity_search_with_relevance_score_with_threshold_and_filter @pytest.mark.parametrize("vector_name", [None, "my-vector"]) def test_qdrant_similarity_search_filters_with_qdrant_filters( - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -265,7 +263,7 @@ def test_qdrant_similarity_search_with_relevance_scores( batch_size: int, content_payload_key: str, metadata_payload_key: str, - vector_name: Optional[str], + vector_name: str | None, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] diff --git a/libs/partners/qdrant/uv.lock b/libs/partners/qdrant/uv.lock index addfbeba15d..48008a243d5 100644 --- a/libs/partners/qdrant/uv.lock +++ b/libs/partners/qdrant/uv.lock @@ -205,7 +205,7 @@ name = "coloredlogs" version = "15.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "humanfriendly" }, + { name = "humanfriendly", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520, upload-time = "2021-06-11T10:22:45.202Z" } wheels = [ @@ -217,7 +217,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -229,17 +229,17 @@ name = "fastembed" version = "0.7.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "huggingface-hub" }, - { name = "loguru" }, - { name = "mmh3" }, + { name = "huggingface-hub", marker = "python_full_version < '3.13'" }, + { name = "loguru", marker = "python_full_version < '3.13'" }, + { name = "mmh3", marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "onnxruntime" }, - { name = "pillow" }, - { name = "py-rust-stemmers" }, - { name = "requests" }, - { name = "tokenizers" }, - { name = "tqdm" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, + { name = "onnxruntime", marker = "python_full_version < '3.13'" }, + { name = "pillow", marker = "python_full_version < '3.13'" }, + { name = "py-rust-stemmers", marker = "python_full_version < '3.13'" }, + { name = "requests", marker = "python_full_version < '3.13'" }, + { name = "tokenizers", marker = "python_full_version < '3.13'" }, + { name = "tqdm", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/65/f6/e8d3d9d487f95b698c9ff0d04d4e050d8fca9fa4cba58cff60fd519d1976/fastembed-0.7.3.tar.gz", hash = "sha256:04e95eb5ccc706513166c23bf8e5429ed160c5783b7b11514431a77624d480a5", size = 66561, upload-time = "2025-08-29T11:19:46.521Z" } wheels = [ @@ -430,14 +430,14 @@ name = "huggingface-hub" version = "0.35.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, - { name = "packaging" }, - { name = "pyyaml" }, - { name = "requests" }, - { name = "tqdm" }, - { name = "typing-extensions" }, + { name = "filelock", marker = "python_full_version < '3.13'" }, + { name = "fsspec", marker = "python_full_version < '3.13'" }, + { name = "hf-xet", marker = "(python_full_version < '3.13' and platform_machine == 'aarch64') or (python_full_version < '3.13' and platform_machine == 'amd64') or (python_full_version < '3.13' and platform_machine == 'arm64') or (python_full_version < '3.13' and platform_machine == 'x86_64')" }, + { name = "packaging", marker = "python_full_version < '3.13'" }, + { name = "pyyaml", marker = "python_full_version < '3.13'" }, + { name = "requests", marker = "python_full_version < '3.13'" }, + { name = "tqdm", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/10/7e/a0a97de7c73671863ca6b3f61fa12518caf35db37825e43d63a70956738c/huggingface_hub-0.35.3.tar.gz", hash = "sha256:350932eaa5cc6a4747efae85126ee220e4ef1b54e29d31c3b45c5612ddf0b32a", size = 461798, upload-time = "2025-09-29T14:29:58.625Z" } wheels = [ @@ -449,7 +449,7 @@ name = "humanfriendly" version = "10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyreadline3", marker = "sys_platform == 'win32'" }, + { name = "pyreadline3", marker = "python_full_version < '3.13' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702, upload-time = "2021-09-17T21:40:43.31Z" } wheels = [ @@ -506,7 +506,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.0.0a6" +version = "1.0.0a8" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -703,8 +703,8 @@ name = "loguru" version = "0.7.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "win32-setctime", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "python_full_version < '3.13' and sys_platform == 'win32'" }, + { name = "win32-setctime", marker = "python_full_version < '3.13' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" } wheels = [ @@ -1169,13 +1169,13 @@ name = "onnxruntime" version = "1.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coloredlogs" }, - { name = "flatbuffers" }, + { name = "coloredlogs", marker = "python_full_version < '3.13'" }, + { name = "flatbuffers", marker = "python_full_version < '3.13'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "packaging" }, - { name = "protobuf" }, - { name = "sympy" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, + { name = "packaging", marker = "python_full_version < '3.13'" }, + { name = "protobuf", marker = "python_full_version < '3.13'" }, + { name = "sympy", marker = "python_full_version < '3.13'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/4e/28/4c76b7feca063d47880e76bee235e829bcc4adb87cc26ecff248ece31f17/onnxruntime-1.23.0-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:009bf5ecad107a7f11af8214fcff19e844214887b38c6673bd63a25af2f6121f", size = 17078761, upload-time = "2025-09-25T19:16:41.541Z" }, @@ -2111,7 +2111,7 @@ name = "sympy" version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mpmath" }, + { name = "mpmath", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ @@ -2144,7 +2144,7 @@ name = "tokenizers" version = "0.22.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "huggingface-hub" }, + { name = "huggingface-hub", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } wheels = [ @@ -2208,7 +2208,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "python_full_version < '3.13' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } wheels = [ diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index e6037e04692..8283b3362a7 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import openai from langchain_core.messages import AIMessageChunk @@ -19,9 +19,8 @@ if TYPE_CHECKING: from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable -_BM = TypeVar("_BM", bound=BaseModel) -_DictOrPydanticClass = Union[dict[str, Any], type[_BM], type] -_DictOrPydantic = Union[dict, _BM] +_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel] | type +_DictOrPydantic: TypeAlias = dict | BaseModel class ChatXAI(BaseChatOpenAI): # type: ignore[override] @@ -31,7 +30,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] for more nuanced details on the API's behavior and supported parameters. Setup: - Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``. + Install `langchain-xai` and set environment variable `XAI_API_KEY`. .. code-block:: bash @@ -43,9 +42,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] model: str Name of model to use. temperature: float - Sampling temperature between `0` and ``2``. Higher values mean more random completions, - while lower values (like ``0.2``) mean more focused and deterministic completions. - (Default: ``1``.) + Sampling temperature between `0` and `2`. Higher values mean more random completions, + while lower values (like `0.2`) mean more focused and deterministic completions. + (Default: `1`.) max_tokens: Optional[int] Max number of tokens to generate. Refer to your `model's documentation `__ for the maximum number of tokens it can generate. @@ -58,7 +57,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] max_retries: int Max number of retries. api_key: Optional[str] - xAI API key. If not passed in will be read from env var ``XAI_API_KEY``. + xAI API key. If not passed in will be read from env var `XAI_API_KEY`. Instantiate: .. code-block:: python @@ -170,9 +169,9 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] If provided, reasoning content is returned under the ``additional_kwargs`` field of the AIMessage or AIMessageChunk. - If supported, reasoning effort can be specified in the model constructor's ``extra_body`` + If supported, reasoning effort can be specified in the model constructor's `extra_body` argument, which will control the amount of reasoning the model does. The value can be one of - ``'low'`` or ``'high'``. + `'low'` or `'high'`. .. code-block:: python @@ -236,7 +235,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] single chunk, instead of being streamed across chunks. Tool choice can be controlled by setting the `tool_choice` parameter in the model - constructor's ``extra_body`` argument. For example, to disable tool / function calling: + constructor's `extra_body` argument. For example, to disable tool / function calling: .. code-block:: python llm = ChatXAI(model="grok-4", extra_body={"tool_choice": "none"}) @@ -399,21 +398,21 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] model_name: str = Field(default="grok-4", alias="model") """Model name to use.""" - xai_api_key: Optional[SecretStr] = Field( + xai_api_key: SecretStr | None = Field( alias="api_key", default_factory=secret_from_env("XAI_API_KEY", default=None), ) """xAI API key. - Automatically read from env variable ``XAI_API_KEY`` if not provided. + Automatically read from env variable `XAI_API_KEY` if not provided. """ xai_api_base: str = Field(default="https://api.x.ai/v1/") """Base URL path for API requests.""" - search_parameters: Optional[dict[str, Any]] = None + search_parameters: dict[str, Any] | None = None """Parameters for search requests. Example: ``{"mode": "auto"}``.""" - openai_api_key: Optional[SecretStr] = None - openai_api_base: Optional[str] = None + openai_api_key: SecretStr | None = None + openai_api_base: str | None = None model_config = ConfigDict( populate_by_name=True, @@ -457,7 +456,7 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, # noqa: ANN401 ) -> LangSmithParams: """Get the parameters used to invoke the model.""" @@ -525,8 +524,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] def _create_chat_result( self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[dict] = None, + response: dict | openai.BaseModel, + generation_info: dict | None = None, ) -> ChatResult: rtn = super()._create_chat_result(response, generation_info) @@ -549,8 +548,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] self, chunk: dict, default_chunk_class: type, - base_generation_info: Optional[dict], - ) -> Optional[ChatGenerationChunk]: + base_generation_info: dict | None, + ) -> ChatGenerationChunk | None: generation_chunk = super()._convert_chunk_to_generation_chunk( chunk, default_chunk_class, @@ -576,13 +575,13 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] def with_structured_output( self, - schema: Optional[_DictOrPydanticClass] = None, + schema: _DictOrPydanticClass | None = None, *, method: Literal[ "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, - strict: Optional[bool] = None, + strict: bool | None = None, **kwargs: Any, # noqa: ANN401 ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema. diff --git a/libs/partners/xai/pyproject.toml b/libs/partners/xai/pyproject.toml index c1dfeb7187e..e4ffa34a66a 100644 --- a/libs/partners/xai/pyproject.toml +++ b/libs/partners/xai/pyproject.toml @@ -70,8 +70,6 @@ ignore = [ "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access - "UP007", # pyupgrade: non-pep604-annotation-union - "UP045", # pyupgrade: non-pep604-annotation-optional "FIX", # TODOs "TD", # TODOs ] diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py index 0a742201efb..dada7d5c4bc 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models_standard.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from langchain_core.messages import AIMessageChunk, BaseMessageChunk from langchain_core.rate_limiters import InMemoryRateLimiter @@ -52,7 +52,7 @@ def test_reasoning_content() -> None: assert response.additional_kwargs["reasoning_content"] # Test streaming - full: Optional[BaseMessageChunk] = None + full: BaseMessageChunk | None = None for chunk in chat_model.stream("What is 3^3?"): full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk)