community: Remove no-untyped-def escapes

cbornet 2025-04-16 14:55:33 +02:00
parent cf2697ec53
commit a60f82b1e2
42 changed files with 424 additions and 188 deletions

View File

@ -575,8 +575,8 @@ class ChatBaichuan(BaseChatModel):
)
return res
def _create_payload_parameters( # type: ignore[no-untyped-def]
self, messages: List[BaseMessage], **kwargs
def _create_payload_parameters(
self, messages: List[BaseMessage], **kwargs: Any
) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
temperature = parameters.pop("temperature", 0.3)
@ -600,7 +600,7 @@ class ChatBaichuan(BaseChatModel):
return payload
def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
default_headers = parameters.pop("headers", {})
api_key = ""
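
The pattern in this first hunk recurs across the whole commit: under mypy's disallow_untyped_defs (the check behind the no-untyped-def error code), a def with an unannotated **kwargs counts as untyped, so the old escape hatch was a blanket # type: ignore. Annotating **kwargs: Any is enough. A minimal sketch of the idea, with an illustrative function name and payload shape that are not from the commit:

    from typing import Any, Dict

    def build_payload(prompt: str, **kwargs: Any) -> Dict[str, Any]:
        # **kwargs: Any types each keyword value as Any; the mapping
        # itself is Dict[str, Any], which satisfies --disallow-untyped-defs.
        return {"prompt": prompt, **kwargs}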

View File

@ -439,8 +439,8 @@ class MiniMaxChat(BaseChatModel):
}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_payload_parameters( # type: ignore[no-untyped-def]
self, messages: List[BaseMessage], is_stream: bool = False, **kwargs
def _create_payload_parameters(
self, messages: List[BaseMessage], is_stream: bool = False, **kwargs: Any
) -> Dict[str, Any]:
"""Create API request body parameters."""
message_dicts = [_convert_message_to_dict(m) for m in messages]

View File

@ -1,5 +1,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Union
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing_extensions import Self
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
@ -65,10 +68,15 @@ class CHMParser(object):
self.file = chm.CHMFile()
self.file.LoadCHM(path)
def __enter__(self): # type: ignore[no-untyped-def]
def __enter__(self) -> Self:
return self
def __exit__(self, exc_type, exc_value, traceback): # type: ignore[no-untyped-def]
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if self.file:
self.file.CloseCHM()
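
The CHMParser change above is the fully typed context-manager protocol: __enter__ returns Self and __exit__ takes the exception triple with TracebackType. A standalone sketch of the same shape, using a hypothetical resource class:

    from types import TracebackType
    from typing import Optional

    from typing_extensions import Self  # typing.Self on Python >= 3.11

    class ManagedResource:
        def __enter__(self) -> Self:
            return self

        def __exit__(
            self,
            exc_type: Optional[type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> None:
            # Returning None (falsy) lets any in-flight exception propagate.
            pass

    with ManagedResource() as resource:
        pass  # resource is inferred as ManagedResource via Self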

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from typing import TYPE_CHECKING, Iterator, Optional, Sequence, Union
from langchain_core.documents import Document
@ -8,6 +8,9 @@ from langchain_community.document_loaders.base import BaseLoader
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
import mwxml
class MWDumpLoader(BaseLoader):
"""Load `MediaWiki` dump from an `XML` file.
@ -60,7 +63,7 @@ class MWDumpLoader(BaseLoader):
self.skip_redirects = skip_redirects
self.stop_on_error = stop_on_error
def _load_dump_file(self): # type: ignore[no-untyped-def]
def _load_dump_file(self) -> "mwxml.Dump":
try:
import mwxml
except ImportError as e:
@ -70,7 +73,7 @@ class MWDumpLoader(BaseLoader):
return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))
def _load_single_page_from_dump(self, page) -> Document: # type: ignore[no-untyped-def, return]
def _load_single_page_from_dump(self, page: "mwxml.Page") -> Document: # type: ignore[return]
"""Parse a single page."""
try:
import mwparserfromhell
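
The MWDumpLoader hunks show how the commit types code against optional third-party packages: import the package under TYPE_CHECKING for the annotations, quote the annotation as a forward reference, and keep the runtime import lazy. A minimal sketch of the pattern (the error message is illustrative):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import mwxml  # seen only by type checkers, never at runtime

    def load_dump(file_path: str, encoding: str = "utf-8") -> "mwxml.Dump":
        try:
            import mwxml  # lazy runtime import of the optional dependency
        except ImportError as e:
            raise ImportError("mwxml is required: pip install mwxml") from e
        return mwxml.Dump.from_file(open(file_path, encoding=encoding))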

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Iterator, List, Optional
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
from langchain_core.documents import Document
@ -34,20 +34,24 @@ class AzureAIDocumentIntelligenceParser(BaseBlobParser):
kwargs = {}
if api_key is None and azure_credential is None:
credential: Union[AzureKeyCredential, TokenCredential]
if azure_credential:
if api_key is not None:
raise ValueError(
"Only one of api_key or azure_credential should be provided."
)
credential = azure_credential
elif api_key is not None:
credential = AzureKeyCredential(api_key)
else:
raise ValueError("Either api_key or azure_credential must be provided.")
if api_key and azure_credential:
raise ValueError(
"Only one of api_key or azure_credential should be provided."
)
if api_version is not None:
kwargs["api_version"] = api_version
self.client = DocumentIntelligenceClient(
endpoint=api_endpoint,
credential=azure_credential or AzureKeyCredential(api_key),
credential=credential,
headers={"x-ms-useragent": "langchain-parser/1.0.0"},
**kwargs,
)
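
This hunk does more than annotate: instead of validating and then passing azure_credential or AzureKeyCredential(api_key) (which mypy rejects because api_key may be None), the new code declares credential with its Union type and assigns exactly one branch. A sketch of that narrowing pattern with stand-in credential classes (the real ones come from azure.core.credentials):

    from typing import Optional, Union

    class KeyCredential:  # stand-in for AzureKeyCredential
        def __init__(self, key: str) -> None:
            self.key = key

    class TokenCredential:  # stand-in for the azure-core protocol
        pass

    def resolve_credential(
        api_key: Optional[str], azure_credential: Optional[TokenCredential]
    ) -> Union[KeyCredential, TokenCredential]:
        # Declaring the union up front lets each branch assign a
        # differently typed value and forces the None/None case to raise.
        credential: Union[KeyCredential, TokenCredential]
        if azure_credential is not None:
            if api_key is not None:
                raise ValueError("Only one of api_key or azure_credential should be provided.")
            credential = azure_credential
        elif api_key is not None:
            credential = KeyCredential(api_key)
        else:
            raise ValueError("Either api_key or azure_credential must be provided.")
        return credential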

View File

@ -169,5 +169,5 @@ class TinyAsyncGradientEmbeddingClient: #: :meta private:
It might be entirely removed in the future.
"""
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ValueError("Deprecated, TinyAsyncGradientEmbeddingClient was removed.")

View File

@ -1,10 +1,13 @@
from enum import Enum
from typing import Any, Dict, Iterator, List, Mapping, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict
if TYPE_CHECKING:
import oci
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
@ -122,12 +125,14 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
client_kwargs.pop("signer", None)
elif values["auth_type"] == OCIAuthType(2).name:
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
def make_security_token_signer(
oci_config: dict[str, Any],
) -> "oci.auth.signers.SecurityTokenSigner":
pk = oci.signer.load_private_key_from_file(
oci_config.get("key_file"), None
)
with open(
oci_config.get("security_token_file"), encoding="utf-8"
str(oci_config.get("security_token_file")), encoding="utf-8"
) as f:
st_string = f.read()
return oci.auth.signers.SecurityTokenSigner(st_string, pk)
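
The str(...) wrapped around oci_config.get("security_token_file") is there because dict.get returns an Optional value, which open() rejects under strict typing. A sketch of the issue and the stricter alternative (an explicit None check instead of the cast); names here are illustrative:

    from typing import Any, Optional

    def read_security_token(oci_config: dict[str, Any]) -> str:
        token_file: Optional[str] = oci_config.get("security_token_file")
        if token_file is None:
            # Explicit narrowing; str(token_file) would also type-check,
            # but would turn a missing key into the bogus path "None".
            raise ValueError("security_token_file missing from config")
        with open(token_file, encoding="utf-8") as f:
            return f.read()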

View File

@ -159,18 +159,20 @@ def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]:
)
def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> list[list[float]]:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**_kwargs: Any) -> Any:
def _completion_with_retry(**_kwargs: Any) -> list[list[float]]:
return _make_request(llm, **_kwargs)
return _completion_with_retry(**kwargs)
def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # type: ignore[no-untyped-def]
def _make_request(
self: YandexGPTEmbeddings, texts: List[str], **kwargs: Any
) -> list[list[float]]:
try:
import grpc
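
The YandexGPT hunk tightens Any returns to list[list[float]] through a tenacity-decorated closure. A compact sketch, assuming the standard tenacity decorators and a placeholder embedding body:

    from typing import Any

    from tenacity import retry, stop_after_attempt, wait_exponential

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, max=10))
    def _embed_with_retry(texts: list[str], **kwargs: Any) -> list[list[float]]:
        # The concrete return type flows through the decorator, so
        # callers keep full type information instead of Any.
        return [[0.0, 0.0, 0.0] for _ in texts]  # placeholder vectors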

View File

@ -93,6 +93,7 @@ class OntotextGraphDBGraph:
self.graph = rdflib.Graph(store, identifier=None, bind_namespaces="none")
self._check_connectivity()
ontology_schema_graph: "rdflib.Graph"
if local_file:
ontology_schema_graph = self._load_ontology_schema_from_file(
local_file,
@ -140,7 +141,9 @@ class OntotextGraphDBGraph:
)
@staticmethod
def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None): # type: ignore[no-untyped-def, assignment]
def _load_ontology_schema_from_file(
local_file: str, local_file_format: Optional[str] = None
) -> "rdflib.ConjunctiveGraph": # type: ignore[assignment]
"""
Parse the ontology schema statements from the provided file
"""
@ -177,7 +180,7 @@ class OntotextGraphDBGraph:
"Invalid query type. Only CONSTRUCT queries are supported."
)
def _load_ontology_schema_with_query(self, query: str): # type: ignore[no-untyped-def]
def _load_ontology_schema_with_query(self, query: str) -> "rdflib.Graph":
"""
Execute the query for collecting the ontology schema statements
"""
@ -188,6 +191,9 @@ class OntotextGraphDBGraph:
except ParserError as e:
raise ValueError(f"Generated SPARQL statement is invalid\n{e}")
if not results.graph:
raise ValueError("Missing graph in results.")
return results.graph
@property

View File

@ -77,7 +77,7 @@ class TigerGraph(GraphStore):
"""
return self._conn.getSchema(force=True)
def refresh_schema(self): # type: ignore[no-untyped-def]
def refresh_schema(self) -> None:
self.generate_schema()
def query(self, query: str) -> Dict[str, Any]: # type: ignore[override]

View File

@ -136,12 +136,14 @@ class OCIGenAIBase(BaseModel, ABC):
client_kwargs.pop("signer", None)
elif values["auth_type"] == OCIAuthType(2).name:
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
def make_security_token_signer(
oci_config: dict[str, Any],
) -> "oci.auth.signers.SecurityTokenSigner":
pk = oci.signer.load_private_key_from_file(
oci_config.get("key_file"), None
)
with open(
oci_config.get("security_token_file"), encoding="utf-8"
str(oci_config.get("security_token_file")), encoding="utf-8"
) as f:
st_string = f.read()
return oci.auth.signers.SecurityTokenSigner(st_string, pk)

View File

@ -77,8 +77,11 @@ class ZapierNLAWrapper(BaseModel):
response.raise_for_status()
return await response.json()
def _create_action_payload( # type: ignore[no-untyped-def]
self, instructions: str, params: Optional[Dict] = None, preview_only=False
def _create_action_payload(
self,
instructions: str,
params: Optional[Dict] = None,
preview_only: bool = False,
) -> Dict:
"""Create a payload for an action."""
data = params if params else {}
@ -95,12 +98,12 @@ class ZapierNLAWrapper(BaseModel):
"""Create a url for an action."""
return self.zapier_nla_api_base + f"exposed/{action_id}/execute/"
def _create_action_request( # type: ignore[no-untyped-def]
def _create_action_request(
self,
action_id: str,
instructions: str,
params: Optional[Dict] = None,
preview_only=False,
preview_only: bool = False,
) -> Request:
data = self._create_action_payload(instructions, params, preview_only)
return Request(
@ -259,39 +262,37 @@ class ZapierNLAWrapper(BaseModel):
)
return response["result"]
def run_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
def run_as_str(self, *args: Any, **kwargs: Any) -> str:
"""Same as run, but returns a stringified version of the JSON for
inserting back into an LLM."""
data = self.run(*args, **kwargs)
return json.dumps(data)
async def arun_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
async def arun_as_str(self, *args: Any, **kwargs: Any) -> str:
"""Same as run, but returns a stringified version of the JSON for
inserting back into an LLM."""
data = await self.arun(*args, **kwargs)
return json.dumps(data)
def preview_as_str(self, *args, **kwargs) -> str: # type: ignore[no-untyped-def]
def preview_as_str(self, *args: Any, **kwargs: Any) -> str:
"""Same as preview, but returns a stringified version of the JSON for
inserting back into an LLM."""
data = self.preview(*args, **kwargs)
return json.dumps(data)
async def apreview_as_str( # type: ignore[no-untyped-def]
self, *args, **kwargs
) -> str:
async def apreview_as_str(self, *args: Any, **kwargs: Any) -> str:
"""Same as preview, but returns a stringified version of the JSON for
inserting back into an LLM."""
data = await self.apreview(*args, **kwargs)
return json.dumps(data)
def list_as_str(self) -> str: # type: ignore[no-untyped-def]
def list_as_str(self) -> str:
"""Same as list, but returns a stringified version of the JSON for
inserting back into an LLM."""
actions = self.list()
return json.dumps(actions)
async def alist_as_str(self) -> str: # type: ignore[no-untyped-def]
async def alist_as_str(self) -> str:
"""Same as list, but returns a stringified version of the JSON for
inserting back into an LLM."""
actions = await self.alist()

View File

@ -7,6 +7,7 @@ import json
import logging
import time
import uuid
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
@ -22,6 +23,7 @@ from typing import (
Type,
Union,
cast,
overload,
)
import numpy as np
@ -80,6 +82,54 @@ FIELDS_METADATA = get_from_env(
MAX_UPLOAD_BATCH_SIZE = 1000
@overload
def _get_search_client(
endpoint: str,
index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None,
semantic_configurations: Optional[
Union[SemanticConfiguration, List[SemanticConfiguration]]
] = None,
scoring_profiles: Optional[List[ScoringProfile]] = None,
default_scoring_profile: Optional[str] = None,
default_fields: Optional[List[SearchField]] = None,
user_agent: Optional[str] = "langchain-comm-python-azure-search",
cors_options: Optional[CorsOptions] = None,
async_: Literal[False] = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
azure_credential: Optional[TokenCredential] = None,
azure_async_credential: Optional[AsyncTokenCredential] = None,
) -> Union[SearchClient]: ...
@overload
def _get_search_client(
endpoint: str,
index_name: str,
key: Optional[str] = None,
azure_ad_access_token: Optional[str] = None,
semantic_configuration_name: Optional[str] = None,
fields: Optional[List[SearchField]] = None,
vector_search: Optional[VectorSearch] = None,
semantic_configurations: Optional[
Union[SemanticConfiguration, List[SemanticConfiguration]]
] = None,
scoring_profiles: Optional[List[ScoringProfile]] = None,
default_scoring_profile: Optional[str] = None,
default_fields: Optional[List[SearchField]] = None,
user_agent: Optional[str] = "langchain-comm-python-azure-search",
cors_options: Optional[CorsOptions] = None,
async_: Literal[True] = True,
additional_search_client_options: Optional[Dict[str, Any]] = None,
azure_credential: Optional[TokenCredential] = None,
azure_async_credential: Optional[AsyncTokenCredential] = None,
) -> Union[AsyncSearchClient]: ...
def _get_search_client(
endpoint: str,
index_name: str,
@ -102,6 +152,7 @@ def _get_search_client(
azure_async_credential: Optional[AsyncTokenCredential] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential
@ -139,22 +190,54 @@ def _get_search_client(
) -> AccessToken:
return self._token
class AsyncTokenCredentialWrapper(AsyncTokenCredential):
def __init__(self, credential: TokenCredential):
self._credential = credential
async def get_token(
self,
*scopes: str,
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
enable_cae: bool = False,
**kwargs: Any,
) -> AccessToken:
return self._credential.get_token(
*scopes,
claims=claims,
tenant_id=tenant_id,
enable_cae=enable_cae,
**kwargs,
)
async def close(self) -> None:
pass
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
pass
additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
credential: Union[AzureKeyCredential, TokenCredential]
async_credential: Union[AzureKeyCredential, AsyncTokenCredential]
# Determine the appropriate credential to use
if key is not None:
if key.upper() == "INTERACTIVE":
credential = InteractiveBrowserCredential()
credential = cast("TokenCredential", InteractiveBrowserCredential())
credential.get_token("https://search.azure.com/.default")
async_credential = credential
async_credential = AsyncTokenCredentialWrapper(credential)
else:
credential = AzureKeyCredential(key)
async_credential = credential
elif azure_ad_access_token is not None:
credential = AzureBearerTokenCredential(azure_ad_access_token)
async_credential = credential
async_credential = AsyncTokenCredentialWrapper(credential)
else:
credential = azure_credential or DefaultAzureCredential()
async_credential = azure_async_credential or AsyncDefaultAzureCredential()
@ -1121,7 +1204,7 @@ class AzureSearch(VectorStore):
search_text=text_query,
vector_queries=[
VectorizedQuery(
vector=np.array(embedding, dtype=np.float32).tolist(),
vector=embedding,
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -1157,7 +1240,7 @@ class AzureSearch(VectorStore):
search_text=text_query,
vector_queries=[
VectorizedQuery(
vector=np.array(embedding, dtype=np.float32).tolist(),
vector=embedding,
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -1302,7 +1385,7 @@ class AzureSearch(VectorStore):
search_text=query,
vector_queries=[
VectorizedQuery(
vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
vector=self.embed_query(query),
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -1390,7 +1473,7 @@ class AzureSearch(VectorStore):
search_text=query,
vector_queries=[
VectorizedQuery(
vector=np.array(vector, dtype=np.float32).tolist(),
vector=vector,
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
@ -1754,7 +1837,7 @@ async def _aresults_to_documents(
async def _areorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
results: AsyncSearchItemPaged[Dict],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
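
The new @overload stubs above pair Literal[False]/Literal[True] on async_ with the two client types, so call sites get SearchClient or AsyncSearchClient back without casts. The mechanism, reduced to stand-in classes:

    from typing import Literal, Union, overload

    class SyncClient: ...
    class AsyncClient: ...

    @overload
    def get_client(index: str, async_: Literal[False] = False) -> SyncClient: ...
    @overload
    def get_client(index: str, async_: Literal[True]) -> AsyncClient: ...
    def get_client(index: str, async_: bool = False) -> Union[SyncClient, AsyncClient]:
        # One runtime implementation; the overloads exist only for the
        # type checker, which picks the return type from the literal.
        return AsyncClient() if async_ else SyncClient()

    sync_client = get_client("idx")                 # inferred: SyncClient
    async_client = get_client("idx", async_=True)   # inferred: AsyncClient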

View File

@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
self._logger.debug("Vector index already exists.")
self._have_index = True
def _create_index_in_background(self): # type: ignore[no-untyped-def]
def _create_index_in_background(self) -> None:
if self._have_index or self._creating_index:
# Already have an index or in the process of creating one.
return
@ -240,7 +240,7 @@ class BigQueryVectorSearch(VectorStore):
thread = Thread(target=self._create_index, daemon=True)
thread.start()
def _create_index(self): # type: ignore[no-untyped-def]
def _create_index(self) -> None:
from google.api_core.exceptions import ClientError
table = self.bq_client.get_table(self.vectors_table)

View File

@ -943,7 +943,7 @@ class DeepLake(VectorStore):
return self.vectorstore.dataset
@classmethod
def _validate_kwargs(cls, kwargs, method_name): # type: ignore[no-untyped-def]
def _validate_kwargs(cls, kwargs: Any, method_name: str) -> None:
if kwargs:
valid_items = cls._get_valid_args(method_name)
unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@ -955,14 +955,14 @@ class DeepLake(VectorStore):
)
@classmethod
def _get_valid_args(cls, method_name): # type: ignore[no-untyped-def]
def _get_valid_args(cls, method_name: str) -> list[str]:
if method_name == "search":
return cls._valid_search_kwargs
else:
return []
@staticmethod
def _get_unsupported_items(kwargs, valid_items): # type: ignore[no-untyped-def]
def _get_unsupported_items(kwargs: Any, valid_items: list[str]) -> Optional[str]:
kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
unsupported_items = None
if kwargs:

View File

@ -15,7 +15,6 @@ from typing import (
Optional,
Pattern,
Tuple,
Type,
)
import numpy as np
@ -23,6 +22,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from typing_extensions import Self
from langchain_community.vectorstores.utils import (
DistanceStrategy,
@ -149,7 +149,7 @@ class HanaDB(VectorStore):
for column_name in self.specific_metadata_columns:
self._check_column(self.table_name, column_name)
def _table_exists(self, table_name) -> bool: # type: ignore[no-untyped-def]
def _table_exists(self, table_name: str) -> bool:
sql_str = (
"SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
" AND TABLE_NAME = ?"
@ -165,9 +165,13 @@ class HanaDB(VectorStore):
cur.close()
return False
def _check_column( # type: ignore[no-untyped-def]
self, table_name, column_name, column_type=None, column_length=None
):
def _check_column(
self,
table_name: str,
column_name: str,
column_type: Optional[list[str]] = None,
column_length: Optional[int] = None,
) -> None:
sql_str = (
"SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
"SCHEMA_NAME = CURRENT_SCHEMA "
@ -406,8 +410,8 @@ class HanaDB(VectorStore):
return []
@classmethod
def from_texts( # type: ignore[no-untyped-def, override]
cls: Type[HanaDB],
def from_texts( # type: ignore[override]
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
@ -420,7 +424,7 @@ class HanaDB(VectorStore):
vector_column_length: int = default_vector_column_length,
*,
specific_metadata_columns: Optional[List[str]] = None,
):
) -> Self:
"""Create a HanaDB instance from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
@ -512,12 +516,12 @@ class HanaDB(VectorStore):
)
order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
where_str, query_tuple = self._create_where_by_filter(filter)
query_tuple = (embedding_as_str,) + tuple(query_tuple)
query_params = (embedding_as_str,) + tuple(query_tuple)
sql_str = sql_str + where_str
sql_str = sql_str + order_str
try:
cur = self.connection.cursor()
cur.execute(sql_str, query_tuple)
cur.execute(sql_str, query_params)
if cur.has_result_set():
rows = cur.fetchall()
for row in rows:
@ -567,15 +571,15 @@ class HanaDB(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def _create_where_by_filter(self, filter): # type: ignore[no-untyped-def]
query_tuple = []
def _create_where_by_filter(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
query_tuple: list[Any] = []
where_str = ""
if filter:
where_str, query_tuple = self._process_filter_object(filter)
where_str = " WHERE " + where_str
return where_str, query_tuple
def _process_filter_object(self, filter): # type: ignore[no-untyped-def]
def _process_filter_object(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
query_tuple = []
where_str = ""
if filter:
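
The from_texts change swaps the explicit cls: Type[HanaDB] annotation for a typing_extensions.Self return type, which also does the right thing for subclasses. In miniature:

    from typing import List

    from typing_extensions import Self  # typing.Self on Python >= 3.11

    class Store:
        def __init__(self, texts: List[str]) -> None:
            self.texts = texts

        @classmethod
        def from_texts(cls, texts: List[str]) -> Self:
            # Self resolves to whichever class the method is called on,
            # so a Store subclass gets its own type back, not Store.
            return cls(texts)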

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from uuid import uuid4
import numpy as np
@ -12,6 +12,9 @@ from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from pymilvus.orm.mutation import MutationResult
logger = logging.getLogger(__name__)
DEFAULT_MILVUS_CONNECTION = {
@ -938,9 +941,9 @@ class Milvus(VectorStore):
ret.append(documents[x])
return ret
def delete( # type: ignore[no-untyped-def]
self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
):
def delete(
self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: Any
) -> MutationResult:
"""Delete by vector ID or boolean expression.
Refer to [Milvus documentation](https://milvus.io/docs/delete_data.md)
for notes and examples of expressions.

View File

@ -12,9 +12,11 @@ from typing import (
Generator,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
import numpy as np
@ -750,7 +752,9 @@ class PGVector(VectorStore):
else:
raise NotImplementedError()
def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
def _create_filter_clause_deprecated(
self, key: str, value: dict[str, Any]
) -> SQLColumnExpression:
"""Deprecated functionality.
This is for backwards compatibility with the JSON based schema for metadata.
@ -819,7 +823,7 @@ class PGVector(VectorStore):
return filter_by_metadata
def _create_filter_clause_json_deprecated(
self, filter: Any
self, filter: Mapping[str, Union[str, dict[str, Any]]]
) -> List[SQLColumnExpression]:
"""Convert filters from IR to SQL clauses.

View File

@ -2,12 +2,16 @@ import importlib
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict
from typing_extensions import Self
if TYPE_CHECKING:
from thirdai import neural_db as ndb
class NeuralDBVectorStore(VectorStore):
@ -25,10 +29,10 @@ class NeuralDBVectorStore(VectorStore):
vectorstore = NeuralDBVectorStore(db=db)
"""
def __init__(self, db: Any) -> None:
def __init__(self, db: "ndb.NeuralDB") -> None:
self.db = db
db: Any = None #: :meta private:
db: "ndb.NeuralDB" = None #: :meta private:
"""NeuralDB instance"""
model_config = ConfigDict(
@ -36,7 +40,7 @@ class NeuralDBVectorStore(VectorStore):
)
@staticmethod
def _verify_thirdai_library(thirdai_key: Optional[str] = None): # type: ignore[no-untyped-def]
def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
try:
from thirdai import licensing
@ -50,11 +54,11 @@ class NeuralDBVectorStore(VectorStore):
)
@classmethod
def from_scratch( # type: ignore[no-untyped-def, no-untyped-def]
def from_scratch(
cls,
thirdai_key: Optional[str] = None,
**model_kwargs,
):
**model_kwargs: Any,
) -> Self:
"""
Create a NeuralDBVectorStore from scratch.
@ -84,11 +88,11 @@ class NeuralDBVectorStore(VectorStore):
return cls(db=ndb.NeuralDB(**model_kwargs)) # type: ignore[call-arg]
@classmethod
def from_checkpoint( # type: ignore[no-untyped-def]
def from_checkpoint(
cls,
checkpoint: Union[str, Path],
thirdai_key: Optional[str] = None,
):
) -> Self:
"""
Create a NeuralDBVectorStore with a base model from a saved checkpoint
@ -163,13 +167,13 @@ class NeuralDBVectorStore(VectorStore):
offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
return [str(offset + i) for i in range(len(texts))] # type: ignore[arg-type]
def insert( # type: ignore[no-untyped-def, no-untyped-def]
def insert(
self,
sources: List[Any],
sources: list[Union[str, "ndb.Document"]],
train: bool = True,
fast_mode: bool = True,
**kwargs,
):
**kwargs: Any,
) -> list[str]:
"""Inserts files / document sources into the vectorstore.
Args:
@ -180,14 +184,16 @@ class NeuralDBVectorStore(VectorStore):
Defaults to True.
"""
sources = self._preprocess_sources(sources)
self.db.insert(
return self.db.insert(
sources=sources,
train=train,
fast_approximation=fast_mode,
**kwargs,
)
def _preprocess_sources(self, sources): # type: ignore[no-untyped-def]
def _preprocess_sources(
self, sources: list[Union[str, "ndb.Document"]]
) -> list["ndb.Document"]:
"""Checks if the provided sources are string paths. If they are, convert
to NeuralDB document objects.
@ -219,7 +225,7 @@ class NeuralDBVectorStore(VectorStore):
)
return preprocessed_sources
def upvote(self, query: str, document_id: Union[int, str]): # type: ignore[no-untyped-def]
def upvote(self, query: str, document_id: Union[int, str]) -> None:
"""The vectorstore upweights the score of a document for a specific query.
This is useful for fine-tuning the vectorstore to user behavior.
@ -229,7 +235,7 @@ class NeuralDBVectorStore(VectorStore):
"""
self.db.text_to_result(query, int(document_id))
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]): # type: ignore[no-untyped-def]
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
"""Given a batch of (query, document id) pairs, the vectorstore upweights
the scores of the document for the corresponding queries.
This is useful for fine-tuning the vectorstore to user behavior.
@ -242,7 +248,7 @@ class NeuralDBVectorStore(VectorStore):
[(query, int(doc_id)) for query, doc_id in query_id_pairs]
)
def associate(self, source: str, target: str): # type: ignore[no-untyped-def]
def associate(self, source: str, target: str) -> None:
"""The vectorstore associates a source phrase with a target phrase.
When the vectorstore sees the source phrase, it will also consider results
that are relevant to the target phrase.
@ -253,7 +259,7 @@ class NeuralDBVectorStore(VectorStore):
"""
self.db.associate(source, target)
def associate_batch(self, text_pairs: List[Tuple[str, str]]): # type: ignore[no-untyped-def]
def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
"""Given a batch of (source, target) pairs, the vectorstore associates
each source phrase with the corresponding target phrase.
@ -291,7 +297,7 @@ class NeuralDBVectorStore(VectorStore):
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e
def save(self, path: str): # type: ignore[no-untyped-def]
def save(self, path: str) -> None:
"""Saves a NeuralDB instance to disk. Can be loaded into memory by
calling NeuralDB.from_checkpoint(path)
@ -325,10 +331,10 @@ class NeuralDBClientVectorStore(VectorStore):
"""
def __init__(self, db: Any) -> None:
def __init__(self, db: "ndb.NeuralDBClient") -> None:
self.db = db
db: Any = None #: :meta private:
db: "ndb.NeuralDBClient" = None #: :meta private:
"""NeuralDB Client instance"""
model_config = ConfigDict(
@ -362,7 +368,7 @@ class NeuralDBClientVectorStore(VectorStore):
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e
def insert(self, documents: List[Dict[str, Any]]): # type: ignore[no-untyped-def, no-untyped-def]
def insert(self, documents: List[Dict[str, Any]]) -> Any:
"""
Inserts documents into the VectorStore and return the corresponding Sources.
@ -446,7 +452,7 @@ class NeuralDBClientVectorStore(VectorStore):
"""
return self.db.insert(documents)
def remove_documents(self, source_ids: List[str]): # type: ignore[no-untyped-def]
def remove_documents(self, source_ids: list[str]) -> None:
"""
Deletes documents from the VectorStore using source ids.

View File

@ -8,6 +8,7 @@ import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from typing_extensions import Self
from langchain_community.vectorstores.utils import maximal_marginal_relevance
@ -31,7 +32,14 @@ class VikingDBConfig(object):
scheme(str):http or https, defaulting to http.
"""
def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"): # type: ignore[no-untyped-def]
def __init__(
self,
host: str = "host",
region: str = "region",
ak: str = "ak",
sk: str = "sk",
scheme: str = "http",
) -> None:
self.host = host
self.region = region
self.ak = ak
@ -397,7 +405,7 @@ class VikingDB(VectorStore):
self.collection.delete_data(ids) # type: ignore[union-attr]
@classmethod
def from_texts( # type: ignore[no-untyped-def, override]
def from_texts( # type: ignore[override]
cls,
texts: List[str],
embedding: Embeddings,
@ -407,7 +415,7 @@ class VikingDB(VectorStore):
index_params: Optional[dict] = None,
drop_old: bool = False,
**kwargs: Any,
):
) -> Self:
"""Create a collection, indexes it and insert data."""
if connection_args is None:
raise Exception("VikingDBConfig does not exist")

View File

@ -57,7 +57,7 @@ def test_add_messages() -> None:
assert len(message_store_another.messages) == 0
def test_tidb_recent_chat_message(): # type: ignore[no-untyped-def]
def test_tidb_recent_chat_message() -> None:
"""Test the TiDBChatMessageHistory with earliest_time parameter."""
import time
from datetime import datetime

View File

@ -5,13 +5,16 @@ You can get a list of models from the bedrock client by running 'bedrock_models(
"""
import os
from typing import Any
from typing import TYPE_CHECKING, Any
import pytest
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_community.llms.bedrock import Bedrock
if TYPE_CHECKING:
from botocore.client import BaseClient
# this is the guardrails id for the model you want to test
GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
# this is the guardrails version for the model you want to test
@ -37,12 +40,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
if reason == "GUARDRAIL_INTERVENED":
self.guardrails_intervened = True
def get_response(self): # type: ignore[no-untyped-def]
def get_response(self) -> bool:
return self.guardrails_intervened
@pytest.fixture(autouse=True)
def bedrock_runtime_client(): # type: ignore[no-untyped-def]
def bedrock_runtime_client() -> "BaseClient":
import boto3
try:
@ -56,7 +59,7 @@ def bedrock_runtime_client(): # type: ignore[no-untyped-def]
@pytest.fixture(autouse=True)
def bedrock_client(): # type: ignore[no-untyped-def]
def bedrock_client() -> "BaseClient":
import boto3
try:
@ -70,7 +73,7 @@ def bedrock_client(): # type: ignore[no-untyped-def]
@pytest.fixture
def bedrock_models(bedrock_client): # type: ignore[no-untyped-def]
def bedrock_models(bedrock_client: "BaseClient") -> dict:
"""List bedrock models."""
response = bedrock_client.list_foundation_models().get("modelSummaries")
models = {}
@ -79,7 +82,9 @@ def bedrock_models(bedrock_client): # type: ignore[no-untyped-def]
return models
def test_claude_instant_v1(bedrock_runtime_client, bedrock_models): # type: ignore[no-untyped-def]
def test_claude_instant_v1(
bedrock_runtime_client: "BaseClient", bedrock_models: dict
) -> None:
try:
llm = Bedrock(
model_id="anthropic.claude-instant-v1",
@ -92,9 +97,9 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models): # type: ign
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query( # type: ignore[no-untyped-def]
bedrock_runtime_client, bedrock_models
):
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
bedrock_runtime_client: "BaseClient", bedrock_models: dict
) -> None:
try:
llm = Bedrock(
model_id="anthropic.claude-instant-v1",
@ -112,9 +117,9 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query( # type: ign
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
def test_amazon_bedrock_guardrails_intervention_for_invalid_query( # type: ignore[no-untyped-def]
bedrock_runtime_client, bedrock_models
):
def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
bedrock_runtime_client: "BaseClient", bedrock_models: dict
) -> None:
try:
handler = BedrockAsyncCallbackHandler()
llm = Bedrock(

View File

@ -1,13 +1,13 @@
"""Unit test for Google Trends API Wrapper."""
import os
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
@patch("serpapi.SerpApiClient.get_json")
def test_unexpected_response(mocked_serpapiclient): # type: ignore[no-untyped-def]
def test_unexpected_response(mocked_serpapiclient: MagicMock) -> None:
os.environ["SERPAPI_API_KEY"] = "123abcd"
resp = {
"search_metadata": {

View File

@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
return True
def assert_documents_equals(actual: List[Document], expected: List[Document]): # type: ignore[no-untyped-def]
def assert_documents_equals(actual: List[Document], expected: List[Document]) -> None:
assert len(actual) == len(expected)
for actual_doc, expected_doc in zip(actual, expected):

View File

@ -1,5 +1,7 @@
"""Test Deep Lake functionality."""
from collections.abc import Iterator
import pytest
from langchain_core.documents import Document
from pytest import FixtureRequest
@ -9,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@pytest.fixture
def deeplake_datastore() -> DeepLake: # type: ignore[misc]
def deeplake_datastore() -> Iterator[DeepLake]:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = DeepLake.from_texts(
@ -53,7 +55,7 @@ def test_deeplake_with_metadatas() -> None:
assert output == [Document(page_content="foo", metadata={"page": "0"})]
def test_deeplake_with_persistence(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
def test_deeplake_with_persistence(deeplake_datastore: DeepLake) -> None:
"""Test end to end construction and search, with persistence."""
output = deeplake_datastore.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"})]
@ -73,7 +75,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None: # type: ignore[
# Or on program exit
def test_deeplake_overwrite_flag(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
def test_deeplake_overwrite_flag(deeplake_datastore: DeepLake) -> None:
"""Test overwrite behavior"""
dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
@ -109,7 +111,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None: # type: ignore[no
output = docsearch.similarity_search("foo", k=1)
def test_similarity_search(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
def test_similarity_search(deeplake_datastore: DeepLake) -> None:
"""Test similarity search."""
distance_metric = "cos"
output = deeplake_datastore.similarity_search(
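
The fixture fix above is worth noting: a fixture that yields (so it can tear down afterwards) is a generator, and its return annotation is Iterator[DeepLake], not DeepLake; the old # type: ignore[misc] was masking exactly that mismatch. A self-contained sketch with a dict standing in for the datastore:

    from collections.abc import Iterator

    import pytest

    @pytest.fixture
    def datastore() -> Iterator[dict]:
        store = {"page": "0"}
        yield store       # the test body runs here
        store.clear()     # teardown after the yield

    def test_datastore(datastore: dict) -> None:
        # pytest unwraps the generator, so the parameter is the dict itself.
        assert datastore["page"] == "0"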

View File

@ -2,6 +2,7 @@
import os
import random
from types import ModuleType
from typing import Any, Dict, List
import numpy as np
@ -29,7 +30,6 @@ TYPE_4B_FILTERING_TEST_CASES = [
),
]
try:
from hdbcli import dbapi
@ -56,15 +56,15 @@ embedding = NormalizedFakeEmbeddings()
class ConfigData:
def __init__(self): # type: ignore[no-untyped-def]
self.conn = None
self.schema_name = ""
def __init__(self) -> None:
self.conn: dbapi.Connection = None
self.schema_name: str = ""
test_setup = ConfigData()
def generateSchemaName(cursor): # type: ignore[no-untyped-def]
def generateSchemaName(cursor: "dbapi.Cursor") -> str:
# return "Langchain"
cursor.execute(
"SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
@ -78,7 +78,7 @@ def generateSchemaName(cursor): # type: ignore[no-untyped-def]
return f"VEC_{uid}"
def setup_module(module): # type: ignore[no-untyped-def]
def setup_module(module: ModuleType) -> None:
test_setup.conn = dbapi.connect(
address=os.environ.get("HANA_DB_ADDRESS"),
port=os.environ.get("HANA_DB_PORT"),
@ -101,7 +101,7 @@ def setup_module(module): # type: ignore[no-untyped-def]
cur.close()
def teardown_module(module): # type: ignore[no-untyped-def]
def teardown_module(module: ModuleType) -> None:
# return
try:
cur = test_setup.conn.cursor()
@ -119,17 +119,17 @@ def texts() -> List[str]:
@pytest.fixture
def metadatas() -> List[str]:
def metadatas() -> List[dict[str, Any]]:
return [
{"start": 0, "end": 100, "quality": "good", "ready": True}, # type: ignore[list-item]
{"start": 100, "end": 200, "quality": "bad", "ready": False}, # type: ignore[list-item]
{"start": 200, "end": 300, "quality": "ugly", "ready": True}, # type: ignore[list-item]
{"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"}, # type: ignore[list-item]
{"start": 300, "quality": "ugly", "Owner": "Steve"}, # type: ignore[list-item]
{"start": 0, "end": 100, "quality": "good", "ready": True},
{"start": 100, "end": 200, "quality": "bad", "ready": False},
{"start": 200, "end": 300, "quality": "ugly", "ready": True},
{"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},
{"start": 300, "quality": "ugly", "Owner": "Steve"},
]
def drop_table(connection, table_name): # type: ignore[no-untyped-def]
def drop_table(connection: "dbapi.Connection", table_name: str) -> None:
try:
cur = connection.cursor()
sql_str = f"DROP TABLE {table_name}"
@ -279,7 +279,7 @@ def test_hanavector_non_existing_table_fixed_vector_length() -> None:
assert vectordb._table_exists(table_name)
vectordb._check_column(
table_name, vector_column, "REAL_VECTOR", vector_column_length
table_name, vector_column, ["REAL_VECTOR"], vector_column_length
)

View File

@ -32,7 +32,7 @@ def fix_distance_precision(
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def __init__(self): # type: ignore[no-untyped-def]
def __init__(self) -> None:
super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)
def embed_documents(self, texts: List[str]) -> List[List[float]]:

View File

@ -1,13 +1,15 @@
import os
import shutil
from collections.abc import Iterator
import pytest
from langchain_core.documents import Document
from langchain_community.vectorstores import NeuralDBVectorStore
@pytest.fixture(scope="session")
def test_csv(): # type: ignore[no-untyped-def]
def test_csv() -> Iterator[str]:
csv = "thirdai-test.csv"
with open(csv, "w") as o:
o.write("column_1,column_2\n")
@ -16,13 +18,13 @@ def test_csv(): # type: ignore[no-untyped-def]
os.remove(csv)
def assert_result_correctness(documents): # type: ignore[no-untyped-def]
def assert_result_correctness(documents: list[Document]) -> None:
assert len(documents) == 1
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-def]
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv])
documents = retriever.similarity_search("column")
@ -30,7 +32,7 @@ def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-d
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untyped-def]
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
checkpoint = "thirdai-test-save.ndb"
if os.path.exists(checkpoint):
shutil.rmtree(checkpoint)
@ -47,7 +49,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv])
# Make sure they don't throw an error.

View File

@ -268,7 +268,7 @@ def vectara3() -> Iterable[Vectara]:
vectara3.delete(doc_ids)
def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[no-untyped-def]
def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None:
# test max marginal relevance
output1 = vectara3.max_marginal_relevance_search(
"generative AI",
@ -299,7 +299,7 @@ def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[
)
def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untyped-def]
def test_vectara_rerankers(vectara3: Vectara) -> None:
# test Vectara multi-lingual reranker
summary_config = SummaryConfig(is_enabled=True, max_results=7, response_lang="eng")
rerank_config = RerankConfig(reranker="rerank_multilingual_v1", rerank_k=50)
@ -375,7 +375,7 @@ def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untype
assert len(output2) > 0
def test_vectara_with_summary(vectara3) -> None: # type: ignore[no-untyped-def]
def test_vectara_with_summary(vectara3: Vectara) -> None:
"""Test vectara summary."""
# test summarization
num_results = 10

View File

@ -43,7 +43,7 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
("role", "role_response"),
[("ai", "assistant"), ("human", "user"), ("chat", "user")],
)
def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def]
def test_edenai_message_role(role: str, role_response: str) -> None:
role = _message_role(role)
assert role == role_response

View File

@ -2,7 +2,7 @@
import json
import os
from typing import Any, AsyncGenerator, Generator, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, cast
from unittest.mock import patch
import pytest
@ -20,6 +20,9 @@ from langchain_community.chat_models.naver import (
_convert_naver_chat_message_to_message,
)
if TYPE_CHECKING:
from httpx_sse import ServerSentEvent
os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key"
os.environ["NCP_APIGW_API_KEY"] = "test_gw_key"
@ -131,7 +134,7 @@ async def test_naver_ainvoke(mock_chat_completion_response: dict) -> None:
assert completed
def _make_completion_response_from_token(token: str): # type: ignore[no-untyped-def]
def _make_completion_response_from_token(token: str) -> "ServerSentEvent":
from httpx_sse import ServerSentEvent
return ServerSentEvent(

View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI LLM service"""
from typing import Any
from unittest.mock import MagicMock
import pytest
@ -10,7 +11,7 @@ from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
def __getattr__(self, val: str) -> Any:
return self[val]
@ -29,7 +30,7 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
provider = model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
def mocked_response(*args: Any) -> MockResponseDict:
response_text = "Assistant chat reply."
response = None
if provider == "cohere":
@ -102,6 +103,8 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
),
}
)
else:
raise ValueError(f"Unsupported provider {provider}")
return response
monkeypatch.setattr(llm.client, "chat", mocked_response)
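
__getattr__(self, val: str) -> Any is the smallest honest signature for this test double: attribute names are strings, and the stored values are heterogeneous. The helper in isolation:

    from typing import Any

    class MockResponseDict(dict):
        def __getattr__(self, val: str) -> Any:
            # Attribute access falls through to dict lookup, which is
            # all the mocked SDK response objects need.
            return self[val]

    resp = MockResponseDict({"status": 200, "data": {"text": "ok"}})
    assert resp.status == 200
    assert resp.data["text"] == "ok"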

View File

@ -20,7 +20,7 @@ class MockEmbed4All(MagicMock):
n_threads: Optional[int] = None,
device: Optional[str] = None,
**kwargs: Any,
): # type: ignore[no-untyped-def]
):
assert model_name == _GPT4ALL_MODEL_NAME

View File

@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
output.embeddings = embeddings
return output
async def aembed(self, *args) -> Any: # type: ignore[no-untyped-def]
async def aembed(self, *args: Any) -> Any:
return self.embed(*args)
class MockGradient(MagicMock):
"""Mock Gradient package."""
def __init__(self, access_token: str, workspace_id, host): # type: ignore[no-untyped-def]
def __init__(self, access_token: str, workspace_id: str, host: str) -> None:
assert access_token == _GRADIENT_SECRET
assert workspace_id == _GRADIENT_WORKSPACE_ID
assert host == _GRADIENT_BASE_URL

View File

@ -1,4 +1,5 @@
import json
from typing import Any
import numpy as np
import requests
@ -23,7 +24,9 @@ def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
) -> requests.Response:
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",
@ -50,7 +53,9 @@ def test_embed_query(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
) -> None:
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",

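Each mock_post now spells out the keyword arguments the code under test passes to requests.post and returns a real requests.Response. A reduced sketch of the monkeypatching pattern with a fully typed stand-in (URL and payload are illustrative):

    from typing import Any

    import requests
    from pytest import MonkeyPatch

    def test_embed_posts_json(monkeypatch: MonkeyPatch) -> None:
        def mock_post(
            url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
        ) -> requests.Response:
            # Assert on exactly what the client sends, then hand back a
            # canned Response so the caller's parsing code still runs.
            assert url == "http://llamafile-host:8080/embedding"
            response = requests.Response()
            response.status_code = 200
            return response

        monkeypatch.setattr(requests, "post", mock_post)
        resp = requests.post(
            "http://llamafile-host:8080/embedding",
            headers={"Content-Type": "application/json"},
            json={"content": "hello"},
            timeout=10.0,
        )
        assert resp.status_code == 200
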
View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI embedding service."""
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
import pytest
@ -7,9 +8,12 @@ from pytest import MonkeyPatch
from langchain_community.embeddings import OCIGenAIEmbeddings
if TYPE_CHECKING:
from oci.generative_ai_inference.models import EmbedTextDetails
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
def __getattr__(self, val: str) -> Any:
return self[val]
@ -26,7 +30,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
client=oci_gen_ai_client,
)
def mocked_response(invocation_obj): # type: ignore[no-untyped-def]
def mocked_response(invocation_obj: "EmbedTextDetails") -> MockResponseDict:
docs = invocation_obj.inputs
embeddings = []

View File

@ -1,14 +1,14 @@
from langchain_community.graphs.neo4j_graph import value_sanitize
def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_small_list() -> None:
small_list = list(range(15)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = {
@ -18,21 +18,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_nested_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_nested_list() -> None:
input_dict = {
"key1": "value1",
"deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]],

View File

@ -1,6 +1,6 @@
import json
from collections import deque
from typing import Any, Dict
from typing import Any, Dict, Optional
import pytest
import requests
@ -39,7 +39,7 @@ def mock_response() -> requests.Response:
return response
def mock_response_stream(): # type: ignore[no-untyped-def]
def mock_response_stream() -> requests.Response:
mock_response = deque(
[
b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n',
@ -48,7 +48,7 @@ def mock_response_stream(): # type: ignore[no-untyped-def]
)
class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def]
def read(self, chunk_size: int) -> Optional[bytes]:
try:
return mock_response.popleft()
except IndexError:
@ -68,7 +68,13 @@ def test_call(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion"
assert headers == {
"Content-Type": "application/json",
@ -94,7 +100,13 @@ def test_call_with_kwargs(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion"
assert headers == {
"Content-Type": "application/json",
@ -138,7 +150,13 @@ def test_streaming(monkeypatch: MonkeyPatch) -> None:
streaming=True,
)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-hostname:8080/completion"
assert headers == {
"Content-Type": "application/json",

View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI LLM service"""
from typing import Any
from unittest.mock import MagicMock
import pytest
@ -9,7 +10,7 @@ from langchain_community.llms.oci_generative_ai import OCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
def __getattr__(self, val: Any) -> Any:
return self[val]
@ -29,7 +30,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
provider = model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
def mocked_response(*args: Any) -> MockResponseDict:
response_text = "This is the completion."
if provider == "cohere":
@ -75,6 +76,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
),
}
)
raise ValueError("Unsupported provider")
monkeypatch.setattr(llm.client, "generate_text", mocked_response)
output = llm.invoke("This is a prompt.", temperature=0.2)
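
Note the added raise ValueError("Unsupported provider") at the end of mocked_response: once the mock is annotated as returning MockResponseDict, mypy flags the implicit return None on the fall-through path, and raising closes it. The general shape:

    def completion_for(provider: str) -> str:
        # With a non-Optional return type, every path must return a
        # value or raise; the trailing raise removes the implicit
        # "return None" path that mypy would otherwise flag.
        if provider == "cohere":
            return "cohere completion"
        if provider == "meta":
            return "meta completion"
        raise ValueError(f"Unsupported provider {provider}")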

View File

@ -1,14 +1,16 @@
from typing import Any, Optional
import requests
from pytest import MonkeyPatch
from langchain_community.llms.ollama import Ollama
def mock_response_stream(): # type: ignore[no-untyped-def]
def mock_response_stream() -> requests.Response:
mock_response = [b'{ "response": "Response chunk 1" }']
class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def]
def read(self, chunk_size: int) -> Optional[bytes]:
try:
return mock_response.pop()
except IndexError:
@ -31,7 +33,14 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -57,7 +66,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -77,7 +93,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -97,7 +120,14 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -145,7 +175,14 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -194,7 +231,14 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",

View File

@ -1,5 +1,6 @@
"""Test building the Zapier tool, not running it."""
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
@ -9,6 +10,9 @@ from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain_community.tools.zapier.tool import ZapierNLARunAction
from langchain_community.utilities.zapier import ZapierNLAWrapper
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_default_base_prompt() -> None:
"""Test that the default prompt is being inserted."""
@ -136,7 +140,7 @@ def test_create_action_payload_with_params() -> None:
assert payload["test"] == "test"
async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def]
async def test_apreview(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",
@ -164,7 +168,7 @@ async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def]
)
async def test_arun(mocker) -> None: # type: ignore[no-untyped-def]
async def test_arun(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",
@ -188,7 +192,7 @@ async def test_arun(mocker) -> None: # type: ignore[no-untyped-def]
)
async def test_alist(mocker) -> None: # type: ignore[no-untyped-def]
async def test_alist(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from unittest.mock import patch
import pytest
@ -9,6 +9,9 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
DEFAULT_VECTOR_DIMENSION = 4
if TYPE_CHECKING:
from azure.search.documents.indexes.models import SearchIndex
class FakeEmbeddingsWithDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
@ -36,7 +39,7 @@ DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def]
def mock_default_index(*args: Any, **kwargs: Any) -> "SearchIndex":
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
@ -155,12 +158,12 @@ def test_init_new_index() -> None:
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import SearchIndex
def no_index(self, name: str): # type: ignore[no-untyped-def]
def no_index(self: SearchIndexClient, name: str) -> SearchIndex:
raise ResourceNotFoundError
created_index: Optional[SearchIndex] = None
def mock_create_index(self, index): # type: ignore[no-untyped-def]
def mock_create_index(self: SearchIndexClient, index: SearchIndex) -> None:
nonlocal created_index
created_index = index
@ -203,7 +206,9 @@ def test_ids_used_correctly() -> None:
def __init__(self) -> None:
self.succeeded: bool = True
def mock_upload_documents(self, documents: List[object]) -> List[Response]: # type: ignore[no-untyped-def]
def mock_upload_documents(
self: SearchClient, documents: List[object]
) -> List[Response]:
# assume all documents uploaded successfully
response = [Response() for _ in documents]
return response