community: Remove no-untyped-def escapes

cbornet 2025-04-16 14:55:33 +02:00
parent cf2697ec53
commit a60f82b1e2
42 changed files with 424 additions and 188 deletions
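
Every change below follows one pattern: instead of silencing mypy's `no-untyped-def` error (emitted under the `disallow_untyped_defs` strictness flag), each signature gets real annotations so the function body is actually type-checked. A minimal sketch of the before/after shape, with a hypothetical function name:

    from typing import Any, Dict

    # Before: the escape hatch waves the untyped signature through,
    # and mypy treats the whole body as dynamically typed.
    # def build_payload(**kwargs):  # type: ignore[no-untyped-def]

    # After: annotating **kwargs and the return type opts the body
    # back into checking.
    def build_payload(**kwargs: Any) -> Dict[str, Any]:
        """Collect keyword arguments into a payload dict."""
        return dict(kwargs)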

View File

@@ -575,8 +575,8 @@ class ChatBaichuan(BaseChatModel):
         )
         return res

-    def _create_payload_parameters(  # type: ignore[no-untyped-def]
-        self, messages: List[BaseMessage], **kwargs
+    def _create_payload_parameters(
+        self, messages: List[BaseMessage], **kwargs: Any
     ) -> Dict[str, Any]:
         parameters = {**self._default_params, **kwargs}
         temperature = parameters.pop("temperature", 0.3)
@@ -600,7 +600,7 @@ class ChatBaichuan(BaseChatModel):
         return payload

-    def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]:  # type: ignore[no-untyped-def]
+    def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
         parameters = {**self._default_params, **kwargs}
         default_headers = parameters.pop("headers", {})
         api_key = ""

View File

@@ -439,8 +439,8 @@ class MiniMaxChat(BaseChatModel):
         }
         return ChatResult(generations=generations, llm_output=llm_output)

-    def _create_payload_parameters(  # type: ignore[no-untyped-def]
-        self, messages: List[BaseMessage], is_stream: bool = False, **kwargs
+    def _create_payload_parameters(
+        self, messages: List[BaseMessage], is_stream: bool = False, **kwargs: Any
     ) -> Dict[str, Any]:
         """Create API request body parameters."""
         message_dicts = [_convert_message_to_dict(m) for m in messages]

View File

@@ -1,5 +1,8 @@
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from typing_extensions import Self

 from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
@@ -65,10 +68,15 @@ class CHMParser(object):
         self.file = chm.CHMFile()
         self.file.LoadCHM(path)

-    def __enter__(self):  # type: ignore[no-untyped-def]
+    def __enter__(self) -> Self:
         return self

-    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore[no-untyped-def]
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         if self.file:
             self.file.CloseCHM()

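The CHM parser hunk above shows the standard typing for a context manager: `__enter__` returns `Self` (so subclasses get their own type back from `with`), and `__exit__` takes the three optional exception parts, with the traceback typed via `types.TracebackType`. A self-contained sketch of the same shape; the `Resource` class is illustrative, not from the commit:

    from types import TracebackType
    from typing import Optional

    from typing_extensions import Self


    class Resource:
        def __enter__(self) -> Self:
            return self

        def __exit__(
            self,
            exc_type: Optional[type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> None:
            # Release whatever the context acquired; returning None
            # (rather than True) lets exceptions propagate.
            pass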
View File

@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterator, Optional, Sequence, Union

 from langchain_core.documents import Document
@@ -8,6 +8,9 @@ from langchain_community.document_loaders.base import BaseLoader

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    import mwxml
+

 class MWDumpLoader(BaseLoader):
     """Load `MediaWiki` dump from an `XML` file.
@@ -60,7 +63,7 @@ class MWDumpLoader(BaseLoader):
         self.skip_redirects = skip_redirects
         self.stop_on_error = stop_on_error

-    def _load_dump_file(self):  # type: ignore[no-untyped-def]
+    def _load_dump_file(self) -> "mwxml.Dump":
         try:
             import mwxml
         except ImportError as e:
@@ -70,7 +73,7 @@ class MWDumpLoader(BaseLoader):

         return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))

-    def _load_single_page_from_dump(self, page) -> Document:  # type: ignore[no-untyped-def, return]
+    def _load_single_page_from_dump(self, page: "mwxml.Page") -> Document:  # type: ignore[return]
         """Parse a single page."""
         try:
             import mwparserfromhell

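The `mwxml` hunks use the recurring trick in this commit for optional dependencies: import the package only under `TYPE_CHECKING`, so annotations can name its types without making it a hard runtime requirement, and quote the annotation because the name does not exist when the module runs. A sketch with a hypothetical optional package `heavydep`:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never imported at runtime.
        import heavydep


    def load_dump(path: str) -> "heavydep.Dump":
        # The real import stays lazy, with a friendly error if absent.
        try:
            import heavydep
        except ImportError as e:
            raise ImportError("Install heavydep to use this loader.") from e
        return heavydep.Dump(path)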
View File

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

 from langchain_core.documents import Document
@@ -34,20 +34,24 @@ class AzureAIDocumentIntelligenceParser(BaseBlobParser):
         kwargs = {}
-        if api_key is None and azure_credential is None:
-            raise ValueError("Either api_key or azure_credential must be provided.")
-
-        if api_key and azure_credential:
-            raise ValueError(
-                "Only one of api_key or azure_credential should be provided."
-            )
+        credential: Union[AzureKeyCredential, TokenCredential]
+        if azure_credential:
+            if api_key is not None:
+                raise ValueError(
+                    "Only one of api_key or azure_credential should be provided."
+                )
+            credential = azure_credential
+        elif api_key is not None:
+            credential = AzureKeyCredential(api_key)
+        else:
+            raise ValueError("Either api_key or azure_credential must be provided.")

         if api_version is not None:
             kwargs["api_version"] = api_version
         self.client = DocumentIntelligenceClient(
             endpoint=api_endpoint,
-            credential=azure_credential or AzureKeyCredential(api_key),
+            credential=credential,
             headers={"x-ms-useragent": "langchain-parser/1.0.0"},
             **kwargs,
         )

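Beyond dropping the ignore, the parser hunk restructures the credential logic so mypy can prove it sound: `credential` is declared once with its union type, each branch assigns exactly one variant, and the final `else` raises, so every path that reaches the client holds a valid credential. The old `azure_credential or AzureKeyCredential(api_key)` expression could not be verified, since `api_key` might still be `None` at that point as far as the checker knows. A reduced sketch of the same shape, with placeholder credential types:

    from typing import Optional, Union


    class KeyCredential:
        def __init__(self, key: str) -> None:
            self.key = key


    class TokenCredential:
        pass


    def pick_credential(
        api_key: Optional[str], token_credential: Optional[TokenCredential]
    ) -> Union[KeyCredential, TokenCredential]:
        credential: Union[KeyCredential, TokenCredential]
        if token_credential:
            if api_key is not None:
                raise ValueError("Only one of api_key or token_credential is allowed.")
            credential = token_credential
        elif api_key is not None:
            credential = KeyCredential(api_key)
        else:
            raise ValueError("Either api_key or token_credential must be provided.")
        return credential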
View File

@@ -169,5 +169,5 @@ class TinyAsyncGradientEmbeddingClient:  #: :meta private:
     It might be entirely removed in the future.
     """

-    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")

View File

@@ -1,10 +1,13 @@
 from enum import Enum
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional

 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import pre_init
 from pydantic import BaseModel, ConfigDict

+if TYPE_CHECKING:
+    import oci
+
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
@@ -122,12 +125,14 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
                 client_kwargs.pop("signer", None)
             elif values["auth_type"] == OCIAuthType(2).name:

-                def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
+                def make_security_token_signer(
+                    oci_config: dict[str, Any],
+                ) -> "oci.auth.signers.SecurityTokenSigner":
                     pk = oci.signer.load_private_key_from_file(
                         oci_config.get("key_file"), None
                     )
                     with open(
-                        oci_config.get("security_token_file"), encoding="utf-8"
+                        str(oci_config.get("security_token_file")), encoding="utf-8"
                     ) as f:
                         st_string = f.read()
                     return oci.auth.signers.SecurityTokenSigner(st_string, pk)

View File

@@ -159,18 +159,20 @@ def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]:
     )


-def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
+def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> list[list[float]]:
     """Use tenacity to retry the embedding call."""
     retry_decorator = _create_retry_decorator(llm)

     @retry_decorator
-    def _completion_with_retry(**_kwargs: Any) -> Any:
+    def _completion_with_retry(**_kwargs: Any) -> list[list[float]]:
         return _make_request(llm, **_kwargs)

     return _completion_with_retry(**kwargs)


-def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs):  # type: ignore[no-untyped-def]
+def _make_request(
+    self: YandexGPTEmbeddings, texts: List[str], **kwargs: Any
+) -> list[list[float]]:
     try:
         import grpc

View File

@@ -93,6 +93,7 @@ class OntotextGraphDBGraph:
         self.graph = rdflib.Graph(store, identifier=None, bind_namespaces="none")
         self._check_connectivity()

+        ontology_schema_graph: "rdflib.Graph"
         if local_file:
             ontology_schema_graph = self._load_ontology_schema_from_file(
                 local_file,
@@ -140,7 +141,9 @@ class OntotextGraphDBGraph:
             )

     @staticmethod
-    def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):  # type: ignore[no-untyped-def, assignment]
+    def _load_ontology_schema_from_file(
+        local_file: str, local_file_format: Optional[str] = None
+    ) -> "rdflib.ConjunctiveGraph":  # type: ignore[assignment]
         """
         Parse the ontology schema statements from the provided file
         """
@@ -177,7 +180,7 @@ class OntotextGraphDBGraph:
                 "Invalid query type. Only CONSTRUCT queries are supported."
             )

-    def _load_ontology_schema_with_query(self, query: str):  # type: ignore[no-untyped-def]
+    def _load_ontology_schema_with_query(self, query: str) -> "rdflib.Graph":
         """
         Execute the query for collecting the ontology schema statements
         """
@@ -188,6 +191,9 @@ class OntotextGraphDBGraph:
         except ParserError as e:
             raise ValueError(f"Generated SPARQL statement is invalid\n{e}")

+        if not results.graph:
+            raise ValueError("Missing graph in results.")
+
         return results.graph

     @property

View File

@@ -77,7 +77,7 @@ class TigerGraph(GraphStore):
         """
         return self._conn.getSchema(force=True)

-    def refresh_schema(self):  # type: ignore[no-untyped-def]
+    def refresh_schema(self) -> None:
         self.generate_schema()

     def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]

View File

@@ -136,12 +136,14 @@ class OCIGenAIBase(BaseModel, ABC):
                 client_kwargs.pop("signer", None)
             elif values["auth_type"] == OCIAuthType(2).name:

-                def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
+                def make_security_token_signer(
+                    oci_config: dict[str, Any],
+                ) -> "oci.auth.signers.SecurityTokenSigner":
                     pk = oci.signer.load_private_key_from_file(
                         oci_config.get("key_file"), None
                     )
                     with open(
-                        oci_config.get("security_token_file"), encoding="utf-8"
+                        str(oci_config.get("security_token_file")), encoding="utf-8"
                     ) as f:
                         st_string = f.read()
                     return oci.auth.signers.SecurityTokenSigner(st_string, pk)

View File

@@ -77,8 +77,11 @@ class ZapierNLAWrapper(BaseModel):
             response.raise_for_status()
             return await response.json()

-    def _create_action_payload(  # type: ignore[no-untyped-def]
-        self, instructions: str, params: Optional[Dict] = None, preview_only=False
+    def _create_action_payload(
+        self,
+        instructions: str,
+        params: Optional[Dict] = None,
+        preview_only: bool = False,
     ) -> Dict:
         """Create a payload for an action."""
         data = params if params else {}
@@ -95,12 +98,12 @@ class ZapierNLAWrapper(BaseModel):
         """Create a url for an action."""
         return self.zapier_nla_api_base + f"exposed/{action_id}/execute/"

-    def _create_action_request(  # type: ignore[no-untyped-def]
+    def _create_action_request(
         self,
         action_id: str,
         instructions: str,
         params: Optional[Dict] = None,
-        preview_only=False,
+        preview_only: bool = False,
     ) -> Request:
         data = self._create_action_payload(instructions, params, preview_only)
         return Request(
@@ -259,39 +262,37 @@ class ZapierNLAWrapper(BaseModel):
             )
         return response["result"]

-    def run_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    def run_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as run, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = self.run(*args, **kwargs)
         return json.dumps(data)

-    async def arun_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    async def arun_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as run, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = await self.arun(*args, **kwargs)
         return json.dumps(data)

-    def preview_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    def preview_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as preview, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = self.preview(*args, **kwargs)
         return json.dumps(data)

-    async def apreview_as_str(  # type: ignore[no-untyped-def]
-        self, *args, **kwargs
-    ) -> str:
+    async def apreview_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as preview, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = await self.apreview(*args, **kwargs)
         return json.dumps(data)

-    def list_as_str(self) -> str:  # type: ignore[no-untyped-def]
+    def list_as_str(self) -> str:
         """Same as list, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         actions = self.list()
         return json.dumps(actions)

-    async def alist_as_str(self) -> str:  # type: ignore[no-untyped-def]
+    async def alist_as_str(self) -> str:
         """Same as list, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         actions = await self.alist()

View File

@@ -7,6 +7,7 @@ import json
 import logging
 import time
 import uuid
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -22,6 +23,7 @@ from typing import (
     Type,
     Union,
     cast,
+    overload,
 )

 import numpy as np
@@ -80,6 +82,54 @@ FIELDS_METADATA = get_from_env(

 MAX_UPLOAD_BATCH_SIZE = 1000

+
+@overload
+def _get_search_client(
+    endpoint: str,
+    index_name: str,
+    key: Optional[str] = None,
+    azure_ad_access_token: Optional[str] = None,
+    semantic_configuration_name: Optional[str] = None,
+    fields: Optional[List[SearchField]] = None,
+    vector_search: Optional[VectorSearch] = None,
+    semantic_configurations: Optional[
+        Union[SemanticConfiguration, List[SemanticConfiguration]]
+    ] = None,
+    scoring_profiles: Optional[List[ScoringProfile]] = None,
+    default_scoring_profile: Optional[str] = None,
+    default_fields: Optional[List[SearchField]] = None,
+    user_agent: Optional[str] = "langchain-comm-python-azure-search",
+    cors_options: Optional[CorsOptions] = None,
+    async_: Literal[False] = False,
+    additional_search_client_options: Optional[Dict[str, Any]] = None,
+    azure_credential: Optional[TokenCredential] = None,
+    azure_async_credential: Optional[AsyncTokenCredential] = None,
+) -> Union[SearchClient]: ...
+
+
+@overload
+def _get_search_client(
+    endpoint: str,
+    index_name: str,
+    key: Optional[str] = None,
+    azure_ad_access_token: Optional[str] = None,
+    semantic_configuration_name: Optional[str] = None,
+    fields: Optional[List[SearchField]] = None,
+    vector_search: Optional[VectorSearch] = None,
+    semantic_configurations: Optional[
+        Union[SemanticConfiguration, List[SemanticConfiguration]]
+    ] = None,
+    scoring_profiles: Optional[List[ScoringProfile]] = None,
+    default_scoring_profile: Optional[str] = None,
+    default_fields: Optional[List[SearchField]] = None,
+    user_agent: Optional[str] = "langchain-comm-python-azure-search",
+    cors_options: Optional[CorsOptions] = None,
+    async_: Literal[True] = True,
+    additional_search_client_options: Optional[Dict[str, Any]] = None,
+    azure_credential: Optional[TokenCredential] = None,
+    azure_async_credential: Optional[AsyncTokenCredential] = None,
+) -> Union[AsyncSearchClient]: ...
+
+
 def _get_search_client(
     endpoint: str,
     index_name: str,
@@ -102,6 +152,7 @@ def _get_search_client(
     azure_async_credential: Optional[AsyncTokenCredential] = None,
 ) -> Union[SearchClient, AsyncSearchClient]:
     from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
+    from azure.core.credentials_async import AsyncTokenCredential
     from azure.core.exceptions import ResourceNotFoundError
     from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
     from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential
@@ -139,22 +190,54 @@ def _get_search_client(
         ) -> AccessToken:
             return self._token

+    class AsyncTokenCredentialWrapper(AsyncTokenCredential):
+        def __init__(self, credential: TokenCredential):
+            self._credential = credential
+
+        async def get_token(
+            self,
+            *scopes: str,
+            claims: Optional[str] = None,
+            tenant_id: Optional[str] = None,
+            enable_cae: bool = False,
+            **kwargs: Any,
+        ) -> AccessToken:
+            return self._credential.get_token(
+                *scopes,
+                claims=claims,
+                tenant_id=tenant_id,
+                enable_cae=enable_cae,
+                **kwargs,
+            )
+
+        async def close(self) -> None:
+            pass
+
+        async def __aexit__(
+            self,
+            exc_type: Optional[Type[BaseException]] = None,
+            exc_value: Optional[BaseException] = None,
+            traceback: Optional[TracebackType] = None,
+        ) -> None:
+            pass
+
     additional_search_client_options = additional_search_client_options or {}
     default_fields = default_fields or []
-    credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
+    credential: Union[AzureKeyCredential, TokenCredential]
+    async_credential: Union[AzureKeyCredential, AsyncTokenCredential]

     # Determine the appropriate credential to use
     if key is not None:
         if key.upper() == "INTERACTIVE":
-            credential = InteractiveBrowserCredential()
+            credential = cast("TokenCredential", InteractiveBrowserCredential())
             credential.get_token("https://search.azure.com/.default")
-            async_credential = credential
+            async_credential = AsyncTokenCredentialWrapper(credential)
         else:
             credential = AzureKeyCredential(key)
             async_credential = credential
     elif azure_ad_access_token is not None:
         credential = AzureBearerTokenCredential(azure_ad_access_token)
-        async_credential = credential
+        async_credential = AsyncTokenCredentialWrapper(credential)
     else:
         credential = azure_credential or DefaultAzureCredential()
         async_credential = azure_async_credential or AsyncDefaultAzureCredential()
@@ -1121,7 +1204,7 @@ class AzureSearch(VectorStore):
             search_text=text_query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(embedding, dtype=np.float32).tolist(),
+                    vector=embedding,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )
@@ -1157,7 +1240,7 @@ class AzureSearch(VectorStore):
             search_text=text_query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(embedding, dtype=np.float32).tolist(),
+                    vector=embedding,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )
@@ -1302,7 +1385,7 @@ class AzureSearch(VectorStore):
             search_text=query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
+                    vector=self.embed_query(query),
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )
@@ -1390,7 +1473,7 @@ class AzureSearch(VectorStore):
             search_text=query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(vector, dtype=np.float32).tolist(),
+                    vector=vector,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )
@@ -1754,7 +1837,7 @@ async def _aresults_to_documents(

 async def _areorder_results_with_maximal_marginal_relevance(
-    results: SearchItemPaged[Dict],
+    results: AsyncSearchItemPaged[Dict],
     query_embedding: np.ndarray,
     lambda_mult: float = 0.5,
     k: int = 4,

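The two `@overload` declarations added above `_get_search_client` encode the sync/async split in the type system: with `async_` typed `Literal[False]` the call resolves to `SearchClient`, with `Literal[True]` to `AsyncSearchClient`, while the lone runtime implementation keeps the broad union. A trimmed sketch of the technique; the client classes here are stand-ins, not the Azure SDK types:

    from typing import Literal, Union, overload


    class SyncClient: ...


    class AsyncClient: ...


    @overload
    def get_client(endpoint: str, async_: Literal[False] = False) -> SyncClient: ...


    @overload
    def get_client(endpoint: str, async_: Literal[True]) -> AsyncClient: ...


    def get_client(
        endpoint: str, async_: bool = False
    ) -> Union[SyncClient, AsyncClient]:
        # One runtime implementation; the overloads above exist only
        # for the type checker and are overwritten at import time.
        return AsyncClient() if async_ else SyncClient()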
View File

@@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
             self._logger.debug("Vector index already exists.")
             self._have_index = True

-    def _create_index_in_background(self):  # type: ignore[no-untyped-def]
+    def _create_index_in_background(self) -> None:
         if self._have_index or self._creating_index:
             # Already have an index or in the process of creating one.
             return
@@ -240,7 +240,7 @@ class BigQueryVectorSearch(VectorStore):
         thread = Thread(target=self._create_index, daemon=True)
         thread.start()

-    def _create_index(self):  # type: ignore[no-untyped-def]
+    def _create_index(self) -> None:
         from google.api_core.exceptions import ClientError

         table = self.bq_client.get_table(self.vectors_table)

View File

@@ -943,7 +943,7 @@ class DeepLake(VectorStore):
         return self.vectorstore.dataset

     @classmethod
-    def _validate_kwargs(cls, kwargs, method_name):  # type: ignore[no-untyped-def]
+    def _validate_kwargs(cls, kwargs: Any, method_name: str) -> None:
         if kwargs:
             valid_items = cls._get_valid_args(method_name)
             unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@@ -955,14 +955,14 @@ class DeepLake(VectorStore):
             )

     @classmethod
-    def _get_valid_args(cls, method_name):  # type: ignore[no-untyped-def]
+    def _get_valid_args(cls, method_name: str) -> list[str]:
         if method_name == "search":
             return cls._valid_search_kwargs
         else:
             return []

     @staticmethod
-    def _get_unsupported_items(kwargs, valid_items):  # type: ignore[no-untyped-def]
+    def _get_unsupported_items(kwargs: Any, valid_items: list[str]) -> Optional[str]:
         kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
         unsupported_items = None
         if kwargs:
if kwargs: if kwargs:

View File

@@ -15,7 +15,6 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
-    Type,
 )

 import numpy as np
@@ -23,6 +22,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.vectorstores import VectorStore
+from typing_extensions import Self

 from langchain_community.vectorstores.utils import (
     DistanceStrategy,
@@ -149,7 +149,7 @@ class HanaDB(VectorStore):
         for column_name in self.specific_metadata_columns:
             self._check_column(self.table_name, column_name)

-    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
+    def _table_exists(self, table_name: str) -> bool:
         sql_str = (
             "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
             " AND TABLE_NAME = ?"
@@ -165,9 +165,13 @@ class HanaDB(VectorStore):
                 cur.close()
         return False

-    def _check_column(  # type: ignore[no-untyped-def]
-        self, table_name, column_name, column_type=None, column_length=None
-    ):
+    def _check_column(
+        self,
+        table_name: str,
+        column_name: str,
+        column_type: Optional[list[str]] = None,
+        column_length: Optional[int] = None,
+    ) -> None:
         sql_str = (
             "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
             "SCHEMA_NAME = CURRENT_SCHEMA "
@@ -406,8 +410,8 @@ class HanaDB(VectorStore):
         return []

     @classmethod
-    def from_texts(  # type: ignore[no-untyped-def, override]
-        cls: Type[HanaDB],
+    def from_texts(  # type: ignore[override]
+        cls,
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
@@ -420,7 +424,7 @@ class HanaDB(VectorStore):
         vector_column_length: int = default_vector_column_length,
         *,
         specific_metadata_columns: Optional[List[str]] = None,
-    ):
+    ) -> Self:
         """Create a HanaDB instance from raw documents.
         This is a user-friendly interface that:
             1. Embeds documents.
@@ -512,12 +516,12 @@ class HanaDB(VectorStore):
         )
         order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
         where_str, query_tuple = self._create_where_by_filter(filter)
-        query_tuple = (embedding_as_str,) + tuple(query_tuple)
+        query_params = (embedding_as_str,) + tuple(query_tuple)
         sql_str = sql_str + where_str
         sql_str = sql_str + order_str
         try:
             cur = self.connection.cursor()
-            cur.execute(sql_str, query_tuple)
+            cur.execute(sql_str, query_params)
             if cur.has_result_set():
                 rows = cur.fetchall()
                 for row in rows:
@@ -567,15 +571,15 @@ class HanaDB(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
-        query_tuple = []
+    def _create_where_by_filter(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
+        query_tuple: list[Any] = []
         where_str = ""
         if filter:
             where_str, query_tuple = self._process_filter_object(filter)
             where_str = " WHERE " + where_str
         return where_str, query_tuple

-    def _process_filter_object(self, filter):  # type: ignore[no-untyped-def]
+    def _process_filter_object(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
         query_tuple = []
         where_str = ""
         if filter:

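Typing `from_texts` as `-> Self` with a bare `cls` (instead of `cls: Type[HanaDB]` and no return annotation) also makes the factory subclass-aware: a subclass calling it is typed as receiving the subclass. A compact sketch:

    from typing_extensions import Self


    class Store:
        def __init__(self, name: str) -> None:
            self.name = name

        @classmethod
        def from_name(cls, name: str) -> Self:
            # cls is the concrete (possibly subclass) type, so the
            # result is typed as that subclass too.
            return cls(name)


    class FancyStore(Store):
        pass


    fancy = FancyStore.from_name("demo")  # typed as FancyStore, not Store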
View File

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 from uuid import uuid4

 import numpy as np
@@ -12,6 +12,9 @@ from langchain_core.vectorstores import VectorStore

 from langchain_community.vectorstores.utils import maximal_marginal_relevance

+if TYPE_CHECKING:
+    from pymilvus.orm.mutation import MutationResult
+
 logger = logging.getLogger(__name__)

 DEFAULT_MILVUS_CONNECTION = {
@@ -938,9 +941,9 @@ class Milvus(VectorStore):
                 ret.append(documents[x])
         return ret

-    def delete(  # type: ignore[no-untyped-def]
-        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
-    ):
+    def delete(
+        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: Any
+    ) -> MutationResult:
         """Delete by vector ID or boolean expression.
         Refer to [Milvus documentation](https://milvus.io/docs/delete_data.md)
         for notes and examples of expressions.

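The `MutationResult` annotation works even though `pymilvus` is imported only under `TYPE_CHECKING` because this module begins with `from __future__ import annotations`, which keeps every annotation an unevaluated string at runtime. A minimal sketch, with a hypothetical module `fakelib`:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from fakelib import Result  # resolved by the type checker only


    def do_delete() -> Result:
        # Safe at runtime: the annotation stays the string "Result".
        raise NotImplementedError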
View File

@@ -12,9 +12,11 @@ from typing import (
     Generator,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     Type,
+    Union,
 )

 import numpy as np
@@ -750,7 +752,9 @@ class PGVector(VectorStore):
         else:
             raise NotImplementedError()

-    def _create_filter_clause_deprecated(self, key, value):  # type: ignore[no-untyped-def]
+    def _create_filter_clause_deprecated(
+        self, key: str, value: dict[str, Any]
+    ) -> SQLColumnExpression:
         """Deprecated functionality.

         This is for backwards compatibility with the JSON based schema for metadata.
@@ -819,7 +823,7 @@ class PGVector(VectorStore):
         return filter_by_metadata

     def _create_filter_clause_json_deprecated(
-        self, filter: Any
+        self, filter: Mapping[str, Union[str, dict[str, Any]]]
     ) -> List[SQLColumnExpression]:
         """Convert filters from IR to SQL clauses.

View File

@@ -2,12 +2,16 @@ import importlib
 import os
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from pydantic import ConfigDict
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from thirdai import neural_db as ndb


 class NeuralDBVectorStore(VectorStore):
@@ -25,10 +29,10 @@ class NeuralDBVectorStore(VectorStore):
         vectorstore = NeuralDBVectorStore(db=db)
     """

-    def __init__(self, db: Any) -> None:
+    def __init__(self, db: "ndb.NeuralDB") -> None:
         self.db = db

-    db: Any = None  #: :meta private:
+    db: "ndb.NeuralDB" = None  #: :meta private:
     """NeuralDB instance"""

     model_config = ConfigDict(
@@ -36,7 +40,7 @@ class NeuralDBVectorStore(VectorStore):
     )

     @staticmethod
-    def _verify_thirdai_library(thirdai_key: Optional[str] = None):  # type: ignore[no-untyped-def]
+    def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
         try:
             from thirdai import licensing
@@ -50,11 +54,11 @@ class NeuralDBVectorStore(VectorStore):
             )

     @classmethod
-    def from_scratch(  # type: ignore[no-untyped-def, no-untyped-def]
+    def from_scratch(
         cls,
         thirdai_key: Optional[str] = None,
-        **model_kwargs,
-    ):
+        **model_kwargs: Any,
+    ) -> Self:
         """
         Create a NeuralDBVectorStore from scratch.
@@ -84,11 +88,11 @@ class NeuralDBVectorStore(VectorStore):
         return cls(db=ndb.NeuralDB(**model_kwargs))  # type: ignore[call-arg]

     @classmethod
-    def from_checkpoint(  # type: ignore[no-untyped-def]
+    def from_checkpoint(
         cls,
         checkpoint: Union[str, Path],
         thirdai_key: Optional[str] = None,
-    ):
+    ) -> Self:
         """
         Create a NeuralDBVectorStore with a base model from a saved checkpoint
@@ -163,13 +167,13 @@ class NeuralDBVectorStore(VectorStore):
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
         return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]

-    def insert(  # type: ignore[no-untyped-def, no-untyped-def]
+    def insert(
         self,
-        sources: List[Any],
+        sources: list[Union[str, "ndb.Document"]],
         train: bool = True,
         fast_mode: bool = True,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> list[str]:
         """Inserts files / document sources into the vectorstore.

         Args:
@@ -180,14 +184,16 @@ class NeuralDBVectorStore(VectorStore):
                 Defaults to True.
         """
         sources = self._preprocess_sources(sources)
-        self.db.insert(
+        return self.db.insert(
             sources=sources,
             train=train,
             fast_approximation=fast_mode,
             **kwargs,
         )

-    def _preprocess_sources(self, sources):  # type: ignore[no-untyped-def]
+    def _preprocess_sources(
+        self, sources: list[Union[str, "ndb.Document"]]
+    ) -> list["ndb.Document"]:
         """Checks if the provided sources are string paths. If they are, convert
         to NeuralDB document objects.
@@ -219,7 +225,7 @@ class NeuralDBVectorStore(VectorStore):
             )
         return preprocessed_sources

-    def upvote(self, query: str, document_id: Union[int, str]):  # type: ignore[no-untyped-def]
+    def upvote(self, query: str, document_id: Union[int, str]) -> None:
         """The vectorstore upweights the score of a document for a specific query.

         This is useful for fine-tuning the vectorstore to user behavior.
@@ -229,7 +235,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.text_to_result(query, int(document_id))

-    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):  # type: ignore[no-untyped-def]
+    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
         """Given a batch of (query, document id) pairs, the vectorstore upweights
         the scores of the document for the corresponding queries.
         This is useful for fine-tuning the vectorstore to user behavior.
@@ -242,7 +248,7 @@ class NeuralDBVectorStore(VectorStore):
             [(query, int(doc_id)) for query, doc_id in query_id_pairs]
         )

-    def associate(self, source: str, target: str):  # type: ignore[no-untyped-def]
+    def associate(self, source: str, target: str) -> None:
         """The vectorstore associates a source phrase with a target phrase.
         When the vectorstore sees the source phrase, it will also consider results
         that are relevant to the target phrase.
@@ -253,7 +259,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.associate(source, target)

-    def associate_batch(self, text_pairs: List[Tuple[str, str]]):  # type: ignore[no-untyped-def]
+    def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
         """Given a batch of (source, target) pairs, the vectorstore associates
         each source phrase with the corresponding target phrase.
@@ -291,7 +297,7 @@ class NeuralDBVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def save(self, path: str):  # type: ignore[no-untyped-def]
+    def save(self, path: str) -> None:
         """Saves a NeuralDB instance to disk. Can be loaded into memory by
         calling NeuralDB.from_checkpoint(path)
@@ -325,10 +331,10 @@ class NeuralDBClientVectorStore(VectorStore):
     """

-    def __init__(self, db: Any) -> None:
+    def __init__(self, db: "ndb.NeuralDBClient") -> None:
         self.db = db

-    db: Any = None  #: :meta private:
+    db: "ndb.NeuralDBClient" = None  #: :meta private:
     """NeuralDB Client instance"""

     model_config = ConfigDict(
@@ -362,7 +368,7 @@ class NeuralDBClientVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def insert(self, documents: List[Dict[str, Any]]):  # type: ignore[no-untyped-def, no-untyped-def]
+    def insert(self, documents: List[Dict[str, Any]]) -> Any:
         """
         Inserts documents into the VectorStore and return the corresponding Sources.
@@ -446,7 +452,7 @@ class NeuralDBClientVectorStore(VectorStore):
         """
         return self.db.insert(documents)

-    def remove_documents(self, source_ids: List[str]):  # type: ignore[no-untyped-def]
+    def remove_documents(self, source_ids: list[str]) -> None:
         """
         Deletes documents from the VectorStore using source ids.

View File

@@ -8,6 +8,7 @@ import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
+from typing_extensions import Self

 from langchain_community.vectorstores.utils import maximal_marginal_relevance

@@ -31,7 +32,14 @@ class VikingDBConfig(object):
         scheme(str):http or https, defaulting to http.
     """

-    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):  # type: ignore[no-untyped-def]
+    def __init__(
+        self,
+        host: str = "host",
+        region: str = "region",
+        ak: str = "ak",
+        sk: str = "sk",
+        scheme: str = "http",
+    ) -> None:
         self.host = host
         self.region = region
         self.ak = ak
@@ -397,7 +405,7 @@ class VikingDB(VectorStore):
         self.collection.delete_data(ids)  # type: ignore[union-attr]

     @classmethod
-    def from_texts(  # type: ignore[no-untyped-def, override]
+    def from_texts(  # type: ignore[override]
         cls,
         texts: List[str],
         embedding: Embeddings,
@@ -407,7 +415,7 @@ class VikingDB(VectorStore):
         index_params: Optional[dict] = None,
         drop_old: bool = False,
         **kwargs: Any,
-    ):
+    ) -> Self:
         """Create a collection, indexes it and insert data."""
         if connection_args is None:
             raise Exception("VikingDBConfig does not exists")

View File

@@ -57,7 +57,7 @@ def test_add_messages() -> None:
     assert len(message_store_another.messages) == 0


-def test_tidb_recent_chat_message():  # type: ignore[no-untyped-def]
+def test_tidb_recent_chat_message() -> None:
     """Test the TiDBChatMessageHistory with earliest_time parameter."""
     import time
     from datetime import datetime

View File

@@ -5,13 +5,16 @@ You can get a list of models from the bedrock client by running 'bedrock_models(
 """

 import os
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import pytest
 from langchain_core.callbacks import AsyncCallbackHandler

 from langchain_community.llms.bedrock import Bedrock

+if TYPE_CHECKING:
+    from botocore.client import BaseClient
+
 # this is the guardrails id for the model you want to test
 GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
 # this is the guardrails version for the model you want to test
@@ -37,12 +40,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
             if reason == "GUARDRAIL_INTERVENED":
                 self.guardrails_intervened = True

-    def get_response(self):  # type: ignore[no-untyped-def]
+    def get_response(self) -> bool:
         return self.guardrails_intervened


 @pytest.fixture(autouse=True)
-def bedrock_runtime_client():  # type: ignore[no-untyped-def]
+def bedrock_runtime_client() -> "BaseClient":
     import boto3

     try:
@@ -56,7 +59,7 @@ def bedrock_runtime_client():  # type: ignore[no-untyped-def]

 @pytest.fixture(autouse=True)
-def bedrock_client():  # type: ignore[no-untyped-def]
+def bedrock_client() -> "BaseClient":
     import boto3

     try:
@@ -70,7 +73,7 @@ def bedrock_client():  # type: ignore[no-untyped-def]

 @pytest.fixture
-def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
+def bedrock_models(bedrock_client: "BaseClient") -> dict:
     """List bedrock models."""
     response = bedrock_client.list_foundation_models().get("modelSummaries")
     models = {}
@@ -79,7 +82,9 @@ def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
     return models


-def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
+def test_claude_instant_v1(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",
@@ -92,9 +97,9 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
-    bedrock_runtime_client, bedrock_models
-):
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",
@@ -112,9 +117,9 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_intervention_for_invalid_query(  # type: ignore[no-untyped-def]
-    bedrock_runtime_client, bedrock_models
-):
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         handler = BedrockAsyncCallbackHandler()
         llm = Bedrock(

View File

@@ -1,13 +1,13 @@
 """Unit test for Google Trends API Wrapper."""

 import os
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper


 @patch("serpapi.SerpApiClient.get_json")
-def test_unexpected_response(mocked_serpapiclient):  # type: ignore[no-untyped-def]
+def test_unexpected_response(mocked_serpapiclient: MagicMock) -> None:
     os.environ["SERPAPI_API_KEY"] = "123abcd"
     resp = {
         "search_metadata": {

View File

@@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
     return True


-def assert_documents_equals(actual: List[Document], expected: List[Document]):  # type: ignore[no-untyped-def]
+def assert_documents_equals(actual: List[Document], expected: List[Document]) -> None:
     assert len(actual) == len(expected)

     for actual_doc, expected_doc in zip(actual, expected):

View File

@@ -1,5 +1,7 @@
 """Test Deep Lake functionality."""

+from collections.abc import Iterator
+
 import pytest
 from langchain_core.documents import Document
 from pytest import FixtureRequest
@@ -9,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

 @pytest.fixture
-def deeplake_datastore() -> DeepLake:  # type: ignore[misc]
+def deeplake_datastore() -> Iterator[DeepLake]:
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": str(i)} for i in range(len(texts))]
     docsearch = DeepLake.from_texts(
@@ -53,7 +55,7 @@ def test_deeplake_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": "0"})]


-def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_deeplake_with_persistence(deeplake_datastore: DeepLake) -> None:
     """Test end to end construction and search, with persistence."""
     output = deeplake_datastore.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": "0"})]
@@ -73,7 +75,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     # Or on program exit


-def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_deeplake_overwrite_flag(deeplake_datastore: DeepLake) -> None:
     """Test overwrite behavior"""

     dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
@@ -109,7 +111,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     output = docsearch.similarity_search("foo", k=1)


-def test_similarity_search(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_similarity_search(deeplake_datastore: DeepLake) -> None:
     """Test similarity search."""
     distance_metric = "cos"
     output = deeplake_datastore.similarity_search(

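The fixture change above follows the pytest typing convention for `yield` fixtures: the fixture function is a generator, so it is annotated `Iterator[T]`, while the tests that consume it annotate the parameter as plain `T`. A small sketch; the `tmp_numbers` fixture is illustrative:

    from collections.abc import Iterator

    import pytest


    @pytest.fixture
    def tmp_numbers() -> Iterator[list[int]]:
        data = [1, 2, 3]
        yield data    # the test receives the yielded value, typed list[int]
        data.clear()  # teardown runs after the test finishes


    def test_sum(tmp_numbers: list[int]) -> None:
        assert sum(tmp_numbers) == 6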
View File

@@ -2,6 +2,7 @@

 import os
 import random
+from types import ModuleType
 from typing import Any, Dict, List

 import numpy as np
@@ -29,7 +30,6 @@ TYPE_4B_FILTERING_TEST_CASES = [
     ),
 ]

-
 try:
     from hdbcli import dbapi
@@ -56,15 +56,15 @@ embedding = NormalizedFakeEmbeddings()

 class ConfigData:
-    def __init__(self):  # type: ignore[no-untyped-def]
-        self.conn = None
-        self.schema_name = ""
+    def __init__(self) -> None:
+        self.conn: dbapi.Connection = None
+        self.schema_name: str = ""


 test_setup = ConfigData()


-def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
+def generateSchemaName(cursor: "dbapi.Cursor") -> str:
     # return "Langchain"
     cursor.execute(
         "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
@@ -78,7 +78,7 @@ def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
     return f"VEC_{uid}"


-def setup_module(module):  # type: ignore[no-untyped-def]
+def setup_module(module: ModuleType) -> None:
     test_setup.conn = dbapi.connect(
         address=os.environ.get("HANA_DB_ADDRESS"),
         port=os.environ.get("HANA_DB_PORT"),
@@ -101,7 +101,7 @@ def setup_module(module):  # type: ignore[no-untyped-def]
     cur.close()


-def teardown_module(module):  # type: ignore[no-untyped-def]
+def teardown_module(module: ModuleType) -> None:
     # return
     try:
         cur = test_setup.conn.cursor()
@@ -119,17 +119,17 @@ def texts() -> List[str]:

 @pytest.fixture
-def metadatas() -> List[str]:
+def metadatas() -> List[dict[str, Any]]:
     return [
-        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
-        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
-        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
-        {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},  # type: ignore[list-item]
-        {"start": 300, "quality": "ugly", "Owner": "Steve"},  # type: ignore[list-item]
+        {"start": 0, "end": 100, "quality": "good", "ready": True},
+        {"start": 100, "end": 200, "quality": "bad", "ready": False},
+        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
+        {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},
+        {"start": 300, "quality": "ugly", "Owner": "Steve"},
     ]


-def drop_table(connection, table_name):  # type: ignore[no-untyped-def]
+def drop_table(connection: "dbapi.Connection", table_name: str) -> None:
     try:
         cur = connection.cursor()
         sql_str = f"DROP TABLE {table_name}"
@@ -279,7 +279,7 @@ def test_hanavector_non_existing_table_fixed_vector_length() -> None:
     assert vectordb._table_exists(table_name)
     vectordb._check_column(
-        table_name, vector_column, "REAL_VECTOR", vector_column_length
+        table_name, vector_column, ["REAL_VECTOR"], vector_column_length
     )

View File

@@ -32,7 +32,7 @@ def fix_distance_precision(
 class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
     """Fake embeddings functionality for testing."""

-    def __init__(self):  # type: ignore[no-untyped-def]
+    def __init__(self) -> None:
         super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)

     def embed_documents(self, texts: List[str]) -> List[List[float]]:

View File

@ -1,13 +1,15 @@
import os import os
import shutil import shutil
from collections.abc import Iterator
import pytest import pytest
from langchain_core.documents import Document
from langchain_community.vectorstores import NeuralDBVectorStore from langchain_community.vectorstores import NeuralDBVectorStore
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_csv(): # type: ignore[no-untyped-def] def test_csv() -> Iterator[str]:
csv = "thirdai-test.csv" csv = "thirdai-test.csv"
with open(csv, "w") as o: with open(csv, "w") as o:
o.write("column_1,column_2\n") o.write("column_1,column_2\n")
@ -16,13 +18,13 @@ def test_csv(): # type: ignore[no-untyped-def]
os.remove(csv) os.remove(csv)
def assert_result_correctness(documents): # type: ignore[no-untyped-def] def assert_result_correctness(documents: list[Document]) -> None:
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two" assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
@pytest.mark.requires("thirdai[neural_db]") @pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-def] def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
retriever = NeuralDBVectorStore.from_scratch() retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv]) retriever.insert([test_csv])
documents = retriever.similarity_search("column") documents = retriever.similarity_search("column")
@ -30,7 +32,7 @@ def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-d
@pytest.mark.requires("thirdai[neural_db]") @pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untyped-def] def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
checkpoint = "thirdai-test-save.ndb" checkpoint = "thirdai-test-save.ndb"
if os.path.exists(checkpoint): if os.path.exists(checkpoint):
shutil.rmtree(checkpoint) shutil.rmtree(checkpoint)
@ -47,7 +49,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
@pytest.mark.requires("thirdai[neural_db]") @pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def] def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
retriever = NeuralDBVectorStore.from_scratch() retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv]) retriever.insert([test_csv])
# Make sure they don't throw an error. # Make sure they don't throw an error.
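The test_csv fixture gains a return annotation of Iterator[str]: a fixture that yields is annotated with an iterator of the yielded type, which also covers the teardown code after the yield. A minimal sketch with a hypothetical file name:

import os
from collections.abc import Iterator

import pytest


@pytest.fixture()
def tmp_csv() -> Iterator[str]:
    path = "example-test.csv"  # hypothetical path
    with open(path, "w") as f:
        f.write("column_1,column_2\n")
    yield path        # the str handed to the test
    os.remove(path)   # teardown runs after the test finishes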

View File

@ -268,7 +268,7 @@ def vectara3() -> Iterable[Vectara]:
vectara3.delete(doc_ids) vectara3.delete(doc_ids)
def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[no-untyped-def] def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None:
# test max marginal relevance # test max marginal relevance
output1 = vectara3.max_marginal_relevance_search( output1 = vectara3.max_marginal_relevance_search(
"generative AI", "generative AI",
@ -299,7 +299,7 @@ def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[
) )
def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untyped-def] def test_vectara_rerankers(vectara3: Vectara) -> None:
# test Vectara multi-lingual reranker # test Vectara multi-lingual reranker
summary_config = SummaryConfig(is_enabled=True, max_results=7, response_lang="eng") summary_config = SummaryConfig(is_enabled=True, max_results=7, response_lang="eng")
rerank_config = RerankConfig(reranker="rerank_multilingual_v1", rerank_k=50) rerank_config = RerankConfig(reranker="rerank_multilingual_v1", rerank_k=50)
@ -375,7 +375,7 @@ def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untype
assert len(output2) > 0 assert len(output2) > 0
def test_vectara_with_summary(vectara3) -> None: # type: ignore[no-untyped-def] def test_vectara_with_summary(vectara3: Vectara) -> None:
"""Test vectara summary.""" """Test vectara summary."""
# test summarization # test summarization
num_results = 10 num_results = 10

View File

@ -43,7 +43,7 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
("role", "role_response"), ("role", "role_response"),
[("ai", "assistant"), ("human", "user"), ("chat", "user")], [("ai", "assistant"), ("human", "user"), ("chat", "user")],
) )
def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def] def test_edenai_message_role(role: str, role_response: str) -> None:
role = _message_role(role) role = _message_role(role)
assert role == role_response assert role == role_response

View File

@ -2,7 +2,7 @@
import json import json
import os import os
from typing import Any, AsyncGenerator, Generator, cast from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, cast
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -20,6 +20,9 @@ from langchain_community.chat_models.naver import (
_convert_naver_chat_message_to_message, _convert_naver_chat_message_to_message,
) )
if TYPE_CHECKING:
from httpx_sse import ServerSentEvent
os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key" os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key"
os.environ["NCP_APIGW_API_KEY"] = "test_gw_key" os.environ["NCP_APIGW_API_KEY"] = "test_gw_key"
@ -131,7 +134,7 @@ async def test_naver_ainvoke(mock_chat_completion_response: dict) -> None:
assert completed assert completed
def _make_completion_response_from_token(token: str): # type: ignore[no-untyped-def] def _make_completion_response_from_token(token: str) -> "ServerSentEvent":
from httpx_sse import ServerSentEvent from httpx_sse import ServerSentEvent
return ServerSentEvent( return ServerSentEvent(
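Both sides of this change rely on the same two-part pattern: the name is imported under TYPE_CHECKING so the annotation resolves for mypy, while the runtime import stays inside the function so httpx_sse remains an optional dependency. A condensed sketch, assuming httpx_sse's ServerSentEvent accepts event and data keywords:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from httpx_sse import ServerSentEvent  # for the annotation only


def make_event(token: str) -> "ServerSentEvent":
    # Deferred import: httpx_sse is only required when this actually runs.
    from httpx_sse import ServerSentEvent

    return ServerSentEvent(event="result", data=token)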

View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI LLM service""" """Test OCI Generative AI LLM service"""
from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -10,7 +11,7 @@ from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
class MockResponseDict(dict): class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def] def __getattr__(self, val: str) -> Any:
return self[val] return self[val]
@ -29,7 +30,7 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
provider = model_id.split(".")[0].lower() provider = model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def] def mocked_response(*args: Any) -> MockResponseDict:
response_text = "Assistant chat reply." response_text = "Assistant chat reply."
response = None response = None
if provider == "cohere": if provider == "cohere":
@ -102,6 +103,8 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
), ),
} }
) )
else:
raise ValueError(f"Unsupported provider {provider}")
return response return response
monkeypatch.setattr(llm.client, "chat", mocked_response) monkeypatch.setattr(llm.client, "chat", mocked_response)
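MockResponseDict imitates the attribute-style access of OCI SDK response objects by routing missing attributes to dict lookup. Annotating val as str is the natural choice, since attribute names are always strings (the llms variant of this class later in the commit uses val: Any, which is looser than it needs to be). A small self-contained sketch:

from typing import Any


class MockResponseDict(dict):
    def __getattr__(self, val: str) -> Any:
        # Fall through to the underlying dict, so r.status == r["status"].
        return self[val]


r = MockResponseDict({"status": 200, "data": "Assistant chat reply."})
assert r.status == 200 and r.data == "Assistant chat reply."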

View File

@ -20,7 +20,7 @@ class MockEmbed4All(MagicMock):
n_threads: Optional[int] = None, n_threads: Optional[int] = None,
device: Optional[str] = None, device: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
): # type: ignore[no-untyped-def] ):
assert model_name == _GPT4ALL_MODEL_NAME assert model_name == _GPT4ALL_MODEL_NAME

View File

@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
output.embeddings = embeddings output.embeddings = embeddings
return output return output
async def aembed(self, *args) -> Any: # type: ignore[no-untyped-def] async def aembed(self, *args: Any) -> Any:
return self.embed(*args) return self.embed(*args)
class MockGradient(MagicMock): class MockGradient(MagicMock):
"""Mock Gradient package.""" """Mock Gradient package."""
def __init__(self, access_token: str, workspace_id, host): # type: ignore[no-untyped-def] def __init__(self, access_token: str, workspace_id: str, host: str) -> None:
assert access_token == _GRADIENT_SECRET assert access_token == _GRADIENT_SECRET
assert workspace_id == _GRADIENT_WORKSPACE_ID assert workspace_id == _GRADIENT_WORKSPACE_ID
assert host == _GRADIENT_BASE_URL assert host == _GRADIENT_BASE_URL

View File

@ -1,4 +1,5 @@
import json import json
from typing import Any
import numpy as np import numpy as np
import requests import requests
@ -23,7 +24,9 @@ def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080", base_url="http://llamafile-host:8080",
) )
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def] def mock_post(
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
) -> requests.Response:
assert url == "http://llamafile-host:8080/embedding" assert url == "http://llamafile-host:8080/embedding"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -50,7 +53,9 @@ def test_embed_query(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080", base_url="http://llamafile-host:8080",
) )
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def] def mock_post(
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
) -> None:
assert url == "http://llamafile-host:8080/embedding" assert url == "http://llamafile-host:8080/embedding"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI embedding service.""" """Test OCI Generative AI embedding service."""
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -7,9 +8,12 @@ from pytest import MonkeyPatch
from langchain_community.embeddings import OCIGenAIEmbeddings from langchain_community.embeddings import OCIGenAIEmbeddings
if TYPE_CHECKING:
from oci.generative_ai_inference.models import EmbedTextDetails
class MockResponseDict(dict): class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def] def __getattr__(self, val: str) -> Any:
return self[val] return self[val]
@ -26,7 +30,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
client=oci_gen_ai_client, client=oci_gen_ai_client,
) )
def mocked_response(invocation_obj): # type: ignore[no-untyped-def] def mocked_response(invocation_obj: "EmbedTextDetails") -> MockResponseDict:
docs = invocation_obj.inputs docs = invocation_obj.inputs
embeddings = [] embeddings = []

View File

@ -1,14 +1,14 @@
from langchain_community.graphs.neo4j_graph import value_sanitize from langchain_community.graphs.neo4j_graph import value_sanitize
def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def] def test_value_sanitize_with_small_list() -> None:
small_list = list(range(15))  # list size < LIST_LIMIT small_list = list(range(15))  # list size < LIST_LIMIT

input_dict = {"key1": "value1", "small_list": small_list} input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list} expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def] def test_value_sanitize_with_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list} input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = { expected_output = {
@ -18,21 +18,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
assert value_sanitize(input_dict) == expected_output assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def] def test_value_sanitize_with_nested_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}} input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}} expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def] def test_value_sanitize_with_dict_in_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]} input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
assert value_sanitize(input_dict) == expected_output assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def] def test_value_sanitize_with_dict_in_nested_list() -> None:
input_dict = { input_dict = {
"key1": "value1", "key1": "value1",
"deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]], "deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]],

View File

@ -1,6 +1,6 @@
import json import json
from collections import deque from collections import deque
from typing import Any, Dict from typing import Any, Dict, Optional
import pytest import pytest
import requests import requests
@ -39,7 +39,7 @@ def mock_response() -> requests.Response:
return response return response
def mock_response_stream(): # type: ignore[no-untyped-def] def mock_response_stream() -> requests.Response:
mock_response = deque( mock_response = deque(
[ [
b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n', b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n',
@ -48,7 +48,7 @@ def mock_response_stream(): # type: ignore[no-untyped-def]
) )
class MockRaw: class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def] def read(self, chunk_size: int) -> Optional[bytes]:
try: try:
return mock_response.popleft() return mock_response.popleft()
except IndexError: except IndexError:
@ -68,7 +68,13 @@ def test_call(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080", base_url="http://llamafile-host:8080",
) )
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion" assert url == "http://llamafile-host:8080/completion"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -94,7 +100,13 @@ def test_call_with_kwargs(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080", base_url="http://llamafile-host:8080",
) )
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion" assert url == "http://llamafile-host:8080/completion"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -138,7 +150,13 @@ def test_streaming(monkeypatch: MonkeyPatch) -> None:
streaming=True, streaming=True,
) )
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-hostname:8080/completion" assert url == "http://llamafile-hostname:8080/completion"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -1,5 +1,6 @@
"""Test OCI Generative AI LLM service""" """Test OCI Generative AI LLM service"""
from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -9,7 +10,7 @@ from langchain_community.llms.oci_generative_ai import OCIGenAI
class MockResponseDict(dict): class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def] def __getattr__(self, val: Any) -> Any:
return self[val] return self[val]
@ -29,7 +30,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
provider = model_id.split(".")[0].lower() provider = model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def] def mocked_response(*args: Any) -> MockResponseDict:
response_text = "This is the completion." response_text = "This is the completion."
if provider == "cohere": if provider == "cohere":
@ -75,6 +76,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
), ),
} }
) )
raise ValueError("Unsupported provider")
monkeypatch.setattr(llm.client, "generate_text", mocked_response) monkeypatch.setattr(llm.client, "generate_text", mocked_response)
output = llm.invoke("This is a prompt.", temperature=0.2) output = llm.invoke("This is a prompt.", temperature=0.2)

View File

@ -1,14 +1,16 @@
from typing import Any, Optional
import requests import requests
from pytest import MonkeyPatch from pytest import MonkeyPatch
from langchain_community.llms.ollama import Ollama from langchain_community.llms.ollama import Ollama
def mock_response_stream(): # type: ignore[no-untyped-def] def mock_response_stream() -> requests.Response:
mock_response = [b'{ "response": "Response chunk 1" }'] mock_response = [b'{ "response": "Response chunk 1" }']
class MockRaw: class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def] def read(self, chunk_size: int) -> Optional[bytes]:
try: try:
return mock_response.pop() return mock_response.pop()
except IndexError: except IndexError:
@ -31,7 +33,14 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300, timeout=300,
) )
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -57,7 +66,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300, timeout=300,
) )
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -77,7 +93,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -97,7 +120,14 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params""" """Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -145,7 +175,14 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
""" """
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -194,7 +231,14 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
""" """
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate" assert url == "https://ollama-hostname:8000/api/generate"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -1,5 +1,6 @@
"""Test building the Zapier tool, not running it.""" """Test building the Zapier tool, not running it."""
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -9,6 +10,9 @@ from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain_community.tools.zapier.tool import ZapierNLARunAction from langchain_community.tools.zapier.tool import ZapierNLARunAction
from langchain_community.utilities.zapier import ZapierNLAWrapper from langchain_community.utilities.zapier import ZapierNLAWrapper
if TYPE_CHECKING:
from pytest_mock import MockerFixture
def test_default_base_prompt() -> None: def test_default_base_prompt() -> None:
"""Test that the default prompt is being inserted.""" """Test that the default prompt is being inserted."""
@ -136,7 +140,7 @@ def test_create_action_payload_with_params() -> None:
assert payload["test"] == "test" assert payload["test"] == "test"
async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def] async def test_apreview(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly.""" """Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction( tool = ZapierNLARunAction(
action_id="test", action_id="test",
@ -164,7 +168,7 @@ async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def]
) )
async def test_arun(mocker) -> None: # type: ignore[no-untyped-def] async def test_arun(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly.""" """Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction( tool = ZapierNLARunAction(
action_id="test", action_id="test",
@ -188,7 +192,7 @@ async def test_arun(mocker) -> None: # type: ignore[no-untyped-def]
) )
async def test_alist(mocker) -> None: # type: ignore[no-untyped-def] async def test_alist(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly.""" """Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction( tool = ZapierNLARunAction(
action_id="test", action_id="test",

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -9,6 +9,9 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
DEFAULT_VECTOR_DIMENSION = 4 DEFAULT_VECTOR_DIMENSION = 4
if TYPE_CHECKING:
from azure.search.documents.indexes.models import SearchIndex
class FakeEmbeddingsWithDimension(FakeEmbeddings): class FakeEmbeddingsWithDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing.""" """Fake embeddings functionality for testing."""
@ -36,7 +39,7 @@ DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension() DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()
def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def] def mock_default_index(*args: Any, **kwargs: Any) -> "SearchIndex":
from azure.search.documents.indexes.models import ( from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration, ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters, ExhaustiveKnnParameters,
@ -155,12 +158,12 @@ def test_init_new_index() -> None:
from azure.search.documents.indexes import SearchIndexClient from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import SearchIndex from azure.search.documents.indexes.models import SearchIndex
def no_index(self, name: str): # type: ignore[no-untyped-def] def no_index(self: SearchIndexClient, name: str) -> SearchIndex:
raise ResourceNotFoundError raise ResourceNotFoundError
created_index: Optional[SearchIndex] = None created_index: Optional[SearchIndex] = None
def mock_create_index(self, index): # type: ignore[no-untyped-def] def mock_create_index(self: SearchIndexClient, index: SearchIndex) -> None:
nonlocal created_index nonlocal created_index
created_index = index created_index = index
@ -203,7 +206,9 @@ def test_ids_used_correctly() -> None:
def __init__(self) -> None: def __init__(self) -> None:
self.succeeded: bool = True self.succeeded: bool = True
def mock_upload_documents(self, documents: List[object]) -> List[Response]: # type: ignore[no-untyped-def] def mock_upload_documents(
self: SearchClient, documents: List[object]
) -> List[Response]:
# assume all documents uploaded successfully # assume all documents uploaded successfully
response = [Response() for _ in documents] response = [Response() for _ in documents]
return response return response
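When a plain function is patched onto a class, it is bound like a method, so the commit annotates its self parameter with the class it will be attached to, documenting which type the function expects to receive. A minimal sketch, assuming azure.search.documents provides SearchClient:

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    from azure.search.documents import SearchClient


class Response:
    def __init__(self) -> None:
        self.succeeded: bool = True


def mock_upload_documents(
    self: "SearchClient", documents: List[object]
) -> List[Response]:
    # Bound as SearchClient.upload_documents in the test via monkeypatch.
    return [Response() for _ in documents]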