commit a60f82b1e2
parent cf2697ec53

community: Remove no-untyped-def escapes
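
The pattern repeated throughout this diff: instead of silencing mypy's no-untyped-def check with an ignore comment, annotate `*args`/`**kwargs` (usually as `Any`) and give the function an explicit return type. A minimal before/after sketch (the method shape is taken from the Zapier changes below):

    from typing import Any

    # Before: the untyped *args/**kwargs forced an escape hatch
    def run_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
        ...

    # After: fully typed signature, no suppression comment needed
    def run_as_str(self, *args: Any, **kwargs: Any) -> str:
        ...
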
@@ -575,8 +575,8 @@ class ChatBaichuan(BaseChatModel):
         )
         return res

-    def _create_payload_parameters(  # type: ignore[no-untyped-def]
-        self, messages: List[BaseMessage], **kwargs
+    def _create_payload_parameters(
+        self, messages: List[BaseMessage], **kwargs: Any
     ) -> Dict[str, Any]:
         parameters = {**self._default_params, **kwargs}
         temperature = parameters.pop("temperature", 0.3)

@@ -600,7 +600,7 @@ class ChatBaichuan(BaseChatModel):

         return payload

-    def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]:  # type: ignore[no-untyped-def]
+    def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
         parameters = {**self._default_params, **kwargs}
         default_headers = parameters.pop("headers", {})
         api_key = ""

@@ -439,8 +439,8 @@ class MiniMaxChat(BaseChatModel):
         }
         return ChatResult(generations=generations, llm_output=llm_output)

-    def _create_payload_parameters(  # type: ignore[no-untyped-def]
-        self, messages: List[BaseMessage], is_stream: bool = False, **kwargs
+    def _create_payload_parameters(
+        self, messages: List[BaseMessage], is_stream: bool = False, **kwargs: Any
     ) -> Dict[str, Any]:
         """Create API request body parameters."""
         message_dicts = [_convert_message_to_dict(m) for m in messages]

@@ -1,5 +1,8 @@
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from typing_extensions import Self

 from langchain_community.document_loaders.unstructured import UnstructuredFileLoader

@@ -65,10 +68,15 @@ class CHMParser(object):
         self.file = chm.CHMFile()
         self.file.LoadCHM(path)

-    def __enter__(self):  # type: ignore[no-untyped-def]
+    def __enter__(self) -> Self:
         return self

-    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore[no-untyped-def]
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         if self.file:
             self.file.CloseCHM()

@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterator, Optional, Sequence, Union

 from langchain_core.documents import Document

@@ -8,6 +8,9 @@ from langchain_community.document_loaders.base import BaseLoader

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    import mwxml
+

 class MWDumpLoader(BaseLoader):
     """Load `MediaWiki` dump from an `XML` file.

@@ -60,7 +63,7 @@ class MWDumpLoader(BaseLoader):
         self.skip_redirects = skip_redirects
         self.stop_on_error = stop_on_error

-    def _load_dump_file(self):  # type: ignore[no-untyped-def]
+    def _load_dump_file(self) -> "mwxml.Dump":
         try:
             import mwxml
         except ImportError as e:

@@ -70,7 +73,7 @@ class MWDumpLoader(BaseLoader):

         return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))

-    def _load_single_page_from_dump(self, page) -> Document:  # type: ignore[no-untyped-def, return]
+    def _load_single_page_from_dump(self, page: "mwxml.Page") -> Document:  # type: ignore[return]
         """Parse a single page."""
         try:
             import mwparserfromhell

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

 from langchain_core.documents import Document

@@ -34,20 +34,24 @@ class AzureAIDocumentIntelligenceParser(BaseBlobParser):

         kwargs = {}

-        if api_key is None and azure_credential is None:
-            raise ValueError("Either api_key or azure_credential must be provided.")
-
-        if api_key and azure_credential:
-            raise ValueError(
-                "Only one of api_key or azure_credential should be provided."
-            )
+        credential: Union[AzureKeyCredential, TokenCredential]
+        if azure_credential:
+            if api_key is not None:
+                raise ValueError(
+                    "Only one of api_key or azure_credential should be provided."
+                )
+            credential = azure_credential
+        elif api_key is not None:
+            credential = AzureKeyCredential(api_key)
+        else:
+            raise ValueError("Either api_key or azure_credential must be provided.")

         if api_version is not None:
             kwargs["api_version"] = api_version

         self.client = DocumentIntelligenceClient(
             endpoint=api_endpoint,
-            credential=azure_credential or AzureKeyCredential(api_key),
+            credential=credential,
             headers={"x-ms-useragent": "langchain-parser/1.0.0"},
             **kwargs,
         )

@@ -169,5 +169,5 @@ class TinyAsyncGradientEmbeddingClient:  #: :meta private:
     It might be entirely removed in the future.
     """

-    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")
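
Note: the CHMParser change above adopts the standard typed context-manager protocol. A minimal sketch under assumed names (`Resource` and `close` are illustrative, not from this commit):

    from types import TracebackType
    from typing import Optional

    from typing_extensions import Self

    class Resource:
        def __enter__(self) -> Self:
            # returning Self keeps the narrowed type for subclasses
            return self

        def __exit__(
            self,
            exc_type: Optional[type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> None:
            self.close()  # hypothetical cleanup hook

        def close(self) -> None:
            pass
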
@@ -1,10 +1,13 @@
 from enum import Enum
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional

 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import pre_init
 from pydantic import BaseModel, ConfigDict

+if TYPE_CHECKING:
+    import oci
+
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"

@@ -122,12 +125,14 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
             client_kwargs.pop("signer", None)
         elif values["auth_type"] == OCIAuthType(2).name:

-            def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
+            def make_security_token_signer(
+                oci_config: dict[str, Any],
+            ) -> "oci.auth.signers.SecurityTokenSigner":
                 pk = oci.signer.load_private_key_from_file(
                     oci_config.get("key_file"), None
                 )
                 with open(
-                    oci_config.get("security_token_file"), encoding="utf-8"
+                    str(oci_config.get("security_token_file")), encoding="utf-8"
                 ) as f:
                     st_string = f.read()
                 return oci.auth.signers.SecurityTokenSigner(st_string, pk)

@@ -159,18 +159,20 @@ def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]:
     )


-def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
+def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> list[list[float]]:
     """Use tenacity to retry the embedding call."""
     retry_decorator = _create_retry_decorator(llm)

     @retry_decorator
-    def _completion_with_retry(**_kwargs: Any) -> Any:
+    def _completion_with_retry(**_kwargs: Any) -> list[list[float]]:
         return _make_request(llm, **_kwargs)

     return _completion_with_retry(**kwargs)


-def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs):  # type: ignore[no-untyped-def]
+def _make_request(
+    self: YandexGPTEmbeddings, texts: List[str], **kwargs: Any
+) -> list[list[float]]:
     try:
         import grpc

@@ -93,6 +93,7 @@ class OntotextGraphDBGraph:
         self.graph = rdflib.Graph(store, identifier=None, bind_namespaces="none")
         self._check_connectivity()

+        ontology_schema_graph: "rdflib.Graph"
         if local_file:
             ontology_schema_graph = self._load_ontology_schema_from_file(
                 local_file,

@@ -140,7 +141,9 @@ class OntotextGraphDBGraph:
             )

     @staticmethod
-    def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):  # type: ignore[no-untyped-def, assignment]
+    def _load_ontology_schema_from_file(
+        local_file: str, local_file_format: Optional[str] = None
+    ) -> "rdflib.ConjunctiveGraph":  # type: ignore[assignment]
         """
         Parse the ontology schema statements from the provided file
         """

@@ -177,7 +180,7 @@ class OntotextGraphDBGraph:
                 "Invalid query type. Only CONSTRUCT queries are supported."
             )

-    def _load_ontology_schema_with_query(self, query: str):  # type: ignore[no-untyped-def]
+    def _load_ontology_schema_with_query(self, query: str) -> "rdflib.Graph":
         """
         Execute the query for collecting the ontology schema statements
         """

@@ -188,6 +191,9 @@ class OntotextGraphDBGraph:
         except ParserError as e:
             raise ValueError(f"Generated SPARQL statement is invalid\n{e}")

+        if not results.graph:
+            raise ValueError("Missing graph in results.")
+
         return results.graph

     @property

@@ -77,7 +77,7 @@ class TigerGraph(GraphStore):
         """
         return self._conn.getSchema(force=True)

-    def refresh_schema(self):  # type: ignore[no-untyped-def]
+    def refresh_schema(self) -> None:
         self.generate_schema()

     def query(self, query: str) -> Dict[str, Any]:  # type: ignore[override]

@@ -136,12 +136,14 @@ class OCIGenAIBase(BaseModel, ABC):
             client_kwargs.pop("signer", None)
         elif values["auth_type"] == OCIAuthType(2).name:

-            def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
+            def make_security_token_signer(
+                oci_config: dict[str, Any],
+            ) -> "oci.auth.signers.SecurityTokenSigner":
                 pk = oci.signer.load_private_key_from_file(
                     oci_config.get("key_file"), None
                 )
                 with open(
-                    oci_config.get("security_token_file"), encoding="utf-8"
+                    str(oci_config.get("security_token_file")), encoding="utf-8"
                 ) as f:
                     st_string = f.read()
                 return oci.auth.signers.SecurityTokenSigner(st_string, pk)

@@ -77,8 +77,11 @@ class ZapierNLAWrapper(BaseModel):
         response.raise_for_status()
         return await response.json()

-    def _create_action_payload(  # type: ignore[no-untyped-def]
-        self, instructions: str, params: Optional[Dict] = None, preview_only=False
+    def _create_action_payload(
+        self,
+        instructions: str,
+        params: Optional[Dict] = None,
+        preview_only: bool = False,
     ) -> Dict:
         """Create a payload for an action."""
         data = params if params else {}

@@ -95,12 +98,12 @@ class ZapierNLAWrapper(BaseModel):
         """Create a url for an action."""
         return self.zapier_nla_api_base + f"exposed/{action_id}/execute/"

-    def _create_action_request(  # type: ignore[no-untyped-def]
+    def _create_action_request(
         self,
         action_id: str,
         instructions: str,
         params: Optional[Dict] = None,
-        preview_only=False,
+        preview_only: bool = False,
     ) -> Request:
         data = self._create_action_payload(instructions, params, preview_only)
         return Request(

@@ -259,39 +262,37 @@ class ZapierNLAWrapper(BaseModel):
         )
         return response["result"]

-    def run_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    def run_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as run, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = self.run(*args, **kwargs)
         return json.dumps(data)

-    async def arun_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    async def arun_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as run, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = await self.arun(*args, **kwargs)
         return json.dumps(data)

-    def preview_as_str(self, *args, **kwargs) -> str:  # type: ignore[no-untyped-def]
+    def preview_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as preview, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = self.preview(*args, **kwargs)
         return json.dumps(data)

-    async def apreview_as_str(  # type: ignore[no-untyped-def]
-        self, *args, **kwargs
-    ) -> str:
+    async def apreview_as_str(self, *args: Any, **kwargs: Any) -> str:
         """Same as preview, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         data = await self.apreview(*args, **kwargs)
         return json.dumps(data)

-    def list_as_str(self) -> str:  # type: ignore[no-untyped-def]
+    def list_as_str(self) -> str:
         """Same as list, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         actions = self.list()
         return json.dumps(actions)

-    async def alist_as_str(self) -> str:  # type: ignore[no-untyped-def]
+    async def alist_as_str(self) -> str:
         """Same as list, but returns a stringified version of the JSON for
         insertting back into an LLM."""
         actions = await self.alist()
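
Note: several files above annotate against optional third-party packages (mwxml, oci) without importing them at runtime: the import sits under `typing.TYPE_CHECKING` and the annotation is a string. A sketch of the idiom (function name illustrative, not from this commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import mwxml  # seen only by the type checker; no runtime dependency

    def load_dump(path: str) -> "mwxml.Dump":
        import mwxml  # the real import is deferred until the function runs

        return mwxml.Dump.from_file(open(path))
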
@@ -7,6 +7,7 @@ import json
 import logging
 import time
 import uuid
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -22,6 +23,7 @@ from typing import (
     Type,
     Union,
     cast,
+    overload,
 )

 import numpy as np

@@ -80,6 +82,54 @@ FIELDS_METADATA = get_from_env(
 MAX_UPLOAD_BATCH_SIZE = 1000


+@overload
+def _get_search_client(
+    endpoint: str,
+    index_name: str,
+    key: Optional[str] = None,
+    azure_ad_access_token: Optional[str] = None,
+    semantic_configuration_name: Optional[str] = None,
+    fields: Optional[List[SearchField]] = None,
+    vector_search: Optional[VectorSearch] = None,
+    semantic_configurations: Optional[
+        Union[SemanticConfiguration, List[SemanticConfiguration]]
+    ] = None,
+    scoring_profiles: Optional[List[ScoringProfile]] = None,
+    default_scoring_profile: Optional[str] = None,
+    default_fields: Optional[List[SearchField]] = None,
+    user_agent: Optional[str] = "langchain-comm-python-azure-search",
+    cors_options: Optional[CorsOptions] = None,
+    async_: Literal[False] = False,
+    additional_search_client_options: Optional[Dict[str, Any]] = None,
+    azure_credential: Optional[TokenCredential] = None,
+    azure_async_credential: Optional[AsyncTokenCredential] = None,
+) -> Union[SearchClient]: ...
+
+
+@overload
+def _get_search_client(
+    endpoint: str,
+    index_name: str,
+    key: Optional[str] = None,
+    azure_ad_access_token: Optional[str] = None,
+    semantic_configuration_name: Optional[str] = None,
+    fields: Optional[List[SearchField]] = None,
+    vector_search: Optional[VectorSearch] = None,
+    semantic_configurations: Optional[
+        Union[SemanticConfiguration, List[SemanticConfiguration]]
+    ] = None,
+    scoring_profiles: Optional[List[ScoringProfile]] = None,
+    default_scoring_profile: Optional[str] = None,
+    default_fields: Optional[List[SearchField]] = None,
+    user_agent: Optional[str] = "langchain-comm-python-azure-search",
+    cors_options: Optional[CorsOptions] = None,
+    async_: Literal[True] = True,
+    additional_search_client_options: Optional[Dict[str, Any]] = None,
+    azure_credential: Optional[TokenCredential] = None,
+    azure_async_credential: Optional[AsyncTokenCredential] = None,
+) -> Union[AsyncSearchClient]: ...
+
+
 def _get_search_client(
     endpoint: str,
     index_name: str,

@@ -102,6 +152,7 @@ def _get_search_client(
     azure_async_credential: Optional[AsyncTokenCredential] = None,
 ) -> Union[SearchClient, AsyncSearchClient]:
     from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
+    from azure.core.credentials_async import AsyncTokenCredential
     from azure.core.exceptions import ResourceNotFoundError
     from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
     from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential

@@ -139,22 +190,54 @@ def _get_search_client(
         ) -> AccessToken:
             return self._token

+    class AsyncTokenCredentialWrapper(AsyncTokenCredential):
+        def __init__(self, credential: TokenCredential):
+            self._credential = credential
+
+        async def get_token(
+            self,
+            *scopes: str,
+            claims: Optional[str] = None,
+            tenant_id: Optional[str] = None,
+            enable_cae: bool = False,
+            **kwargs: Any,
+        ) -> AccessToken:
+            return self._credential.get_token(
+                *scopes,
+                claims=claims,
+                tenant_id=tenant_id,
+                enable_cae=enable_cae,
+                **kwargs,
+            )
+
+        async def close(self) -> None:
+            pass
+
+        async def __aexit__(
+            self,
+            exc_type: Optional[Type[BaseException]] = None,
+            exc_value: Optional[BaseException] = None,
+            traceback: Optional[TracebackType] = None,
+        ) -> None:
+            pass
+
     additional_search_client_options = additional_search_client_options or {}
     default_fields = default_fields or []
-    credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential]
+    credential: Union[AzureKeyCredential, TokenCredential]
+    async_credential: Union[AzureKeyCredential, AsyncTokenCredential]

     # Determine the appropriate credential to use
     if key is not None:
         if key.upper() == "INTERACTIVE":
-            credential = InteractiveBrowserCredential()
+            credential = cast("TokenCredential", InteractiveBrowserCredential())
             credential.get_token("https://search.azure.com/.default")
-            async_credential = credential
+            async_credential = AsyncTokenCredentialWrapper(credential)
         else:
             credential = AzureKeyCredential(key)
             async_credential = credential
     elif azure_ad_access_token is not None:
         credential = AzureBearerTokenCredential(azure_ad_access_token)
-        async_credential = credential
+        async_credential = AsyncTokenCredentialWrapper(credential)
     else:
         credential = azure_credential or DefaultAzureCredential()
         async_credential = azure_async_credential or AsyncDefaultAzureCredential()

@@ -1121,7 +1204,7 @@ class AzureSearch(VectorStore):
             search_text=text_query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(embedding, dtype=np.float32).tolist(),
+                    vector=embedding,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )

@@ -1157,7 +1240,7 @@ class AzureSearch(VectorStore):
             search_text=text_query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(embedding, dtype=np.float32).tolist(),
+                    vector=embedding,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )

@@ -1302,7 +1385,7 @@ class AzureSearch(VectorStore):
             search_text=query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
+                    vector=self.embed_query(query),
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )

@@ -1390,7 +1473,7 @@ class AzureSearch(VectorStore):
             search_text=query,
             vector_queries=[
                 VectorizedQuery(
-                    vector=np.array(vector, dtype=np.float32).tolist(),
+                    vector=vector,
                     k_nearest_neighbors=k,
                     fields=FIELDS_CONTENT_VECTOR,
                 )

@@ -1754,7 +1837,7 @@ async def _aresults_to_documents(


 async def _areorder_results_with_maximal_marginal_relevance(
-    results: SearchItemPaged[Dict],
+    results: AsyncSearchItemPaged[Dict],
     query_embedding: np.ndarray,
     lambda_mult: float = 0.5,
     k: int = 4,
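
Note: the `_get_search_client` overloads above use `Literal` flag types so callers get a precise return type from the `async_` argument. A self-contained sketch of the same trick (the client classes here are stand-ins for the Azure ones):

    from typing import Literal, Union, overload

    class SyncClient: ...
    class AsyncClient: ...

    @overload
    def get_client(async_: Literal[False] = ...) -> SyncClient: ...
    @overload
    def get_client(async_: Literal[True] = ...) -> AsyncClient: ...
    def get_client(async_: bool = False) -> Union[SyncClient, AsyncClient]:
        # get_client(async_=True) type-checks as AsyncClient;
        # get_client() or get_client(async_=False) as SyncClient
        return AsyncClient() if async_ else SyncClient()
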
@@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
             self._logger.debug("Vector index already exists.")
             self._have_index = True

-    def _create_index_in_background(self):  # type: ignore[no-untyped-def]
+    def _create_index_in_background(self) -> None:
         if self._have_index or self._creating_index:
             # Already have an index or in the process of creating one.
             return

@@ -240,7 +240,7 @@ class BigQueryVectorSearch(VectorStore):
         thread = Thread(target=self._create_index, daemon=True)
         thread.start()

-    def _create_index(self):  # type: ignore[no-untyped-def]
+    def _create_index(self) -> None:
         from google.api_core.exceptions import ClientError

         table = self.bq_client.get_table(self.vectors_table)

@@ -943,7 +943,7 @@ class DeepLake(VectorStore):
         return self.vectorstore.dataset

     @classmethod
-    def _validate_kwargs(cls, kwargs, method_name):  # type: ignore[no-untyped-def]
+    def _validate_kwargs(cls, kwargs: Any, method_name: str) -> None:
         if kwargs:
             valid_items = cls._get_valid_args(method_name)
             unsupported_items = cls._get_unsupported_items(kwargs, valid_items)

@@ -955,14 +955,14 @@ class DeepLake(VectorStore):
             )

     @classmethod
-    def _get_valid_args(cls, method_name):  # type: ignore[no-untyped-def]
+    def _get_valid_args(cls, method_name: str) -> list[str]:
         if method_name == "search":
             return cls._valid_search_kwargs
         else:
             return []

     @staticmethod
-    def _get_unsupported_items(kwargs, valid_items):  # type: ignore[no-untyped-def]
+    def _get_unsupported_items(kwargs: Any, valid_items: list[str]) -> Optional[str]:
         kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
         unsupported_items = None
         if kwargs:

@@ -15,7 +15,6 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
-    Type,
 )

 import numpy as np

@@ -23,6 +22,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.vectorstores import VectorStore
+from typing_extensions import Self

 from langchain_community.vectorstores.utils import (
     DistanceStrategy,

@@ -149,7 +149,7 @@ class HanaDB(VectorStore):
         for column_name in self.specific_metadata_columns:
             self._check_column(self.table_name, column_name)

-    def _table_exists(self, table_name) -> bool:  # type: ignore[no-untyped-def]
+    def _table_exists(self, table_name: str) -> bool:
         sql_str = (
             "SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
             " AND TABLE_NAME = ?"

@@ -165,9 +165,13 @@ class HanaDB(VectorStore):
             cur.close()
         return False

-    def _check_column(  # type: ignore[no-untyped-def]
-        self, table_name, column_name, column_type=None, column_length=None
-    ):
+    def _check_column(
+        self,
+        table_name: str,
+        column_name: str,
+        column_type: Optional[list[str]] = None,
+        column_length: Optional[int] = None,
+    ) -> None:
         sql_str = (
             "SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
             "SCHEMA_NAME = CURRENT_SCHEMA "

@@ -406,8 +410,8 @@ class HanaDB(VectorStore):
         return []

     @classmethod
-    def from_texts(  # type: ignore[no-untyped-def, override]
-        cls: Type[HanaDB],
+    def from_texts(  # type: ignore[override]
+        cls,
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,

@@ -420,7 +424,7 @@ class HanaDB(VectorStore):
         vector_column_length: int = default_vector_column_length,
         *,
         specific_metadata_columns: Optional[List[str]] = None,
-    ):
+    ) -> Self:
         """Create a HanaDB instance from raw documents.
         This is a user-friendly interface that:
         1. Embeds documents.

@@ -512,12 +516,12 @@ class HanaDB(VectorStore):
         )
         order_str = f" order by CS {HANA_DISTANCE_FUNCTION[self.distance_strategy][1]}"
         where_str, query_tuple = self._create_where_by_filter(filter)
-        query_tuple = (embedding_as_str,) + tuple(query_tuple)
+        query_params = (embedding_as_str,) + tuple(query_tuple)
         sql_str = sql_str + where_str
         sql_str = sql_str + order_str
         try:
             cur = self.connection.cursor()
-            cur.execute(sql_str, query_tuple)
+            cur.execute(sql_str, query_params)
             if cur.has_result_set():
                 rows = cur.fetchall()
                 for row in rows:

@@ -567,15 +571,15 @@ class HanaDB(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

-    def _create_where_by_filter(self, filter):  # type: ignore[no-untyped-def]
-        query_tuple = []
+    def _create_where_by_filter(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
+        query_tuple: list[Any] = []
         where_str = ""
         if filter:
             where_str, query_tuple = self._process_filter_object(filter)
             where_str = " WHERE " + where_str
         return where_str, query_tuple

-    def _process_filter_object(self, filter):  # type: ignore[no-untyped-def]
+    def _process_filter_object(self, filter: Optional[dict]) -> Tuple[str, list[Any]]:
         query_tuple = []
         where_str = ""
         if filter:
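
Note: `from_texts` above now returns `typing_extensions.Self` instead of being untyped, so subclasses inherit a correctly narrowed return type from the factory classmethod. A minimal sketch (`Store` is illustrative, not from this commit):

    from typing_extensions import Self

    class Store:
        def __init__(self) -> None:
            self.texts: list[str] = []

        @classmethod
        def from_texts(cls, texts: list[str]) -> Self:
            # Self binds to whichever class the method is called on,
            # so SubStore.from_texts([...]) checks as SubStore, not Store
            inst = cls()
            inst.texts = list(texts)
            return inst
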
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 from uuid import uuid4

 import numpy as np

@@ -12,6 +12,9 @@ from langchain_core.vectorstores import VectorStore

 from langchain_community.vectorstores.utils import maximal_marginal_relevance

+if TYPE_CHECKING:
+    from pymilvus.orm.mutation import MutationResult
+
 logger = logging.getLogger(__name__)

 DEFAULT_MILVUS_CONNECTION = {

@@ -938,9 +941,9 @@ class Milvus(VectorStore):
                 ret.append(documents[x])
         return ret

-    def delete(  # type: ignore[no-untyped-def]
-        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
-    ):
+    def delete(
+        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: Any
+    ) -> MutationResult:
         """Delete by vector ID or boolean expression.
         Refer to [Milvus documentation](https://milvus.io/docs/delete_data.md)
         for notes and examples of expressions.

@@ -12,9 +12,11 @@ from typing import (
     Generator,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )

 import numpy as np

@@ -750,7 +752,9 @@ class PGVector(VectorStore):
         else:
             raise NotImplementedError()

-    def _create_filter_clause_deprecated(self, key, value):  # type: ignore[no-untyped-def]
+    def _create_filter_clause_deprecated(
+        self, key: str, value: dict[str, Any]
+    ) -> SQLColumnExpression:
         """Deprecated functionality.

         This is for backwards compatibility with the JSON based schema for metadata.

@@ -819,7 +823,7 @@ class PGVector(VectorStore):
         return filter_by_metadata

     def _create_filter_clause_json_deprecated(
-        self, filter: Any
+        self, filter: Mapping[str, Union[str, dict[str, Any]]]
     ) -> List[SQLColumnExpression]:
         """Convert filters from IR to SQL clauses.

@@ -2,12 +2,16 @@ import importlib
 import os
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from pydantic import ConfigDict
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from thirdai import neural_db as ndb


 class NeuralDBVectorStore(VectorStore):

@@ -25,10 +29,10 @@ class NeuralDBVectorStore(VectorStore):
             vectorstore = NeuralDBVectorStore(db=db)
     """

-    def __init__(self, db: Any) -> None:
+    def __init__(self, db: "ndb.NeuralDB") -> None:
         self.db = db

-    db: Any = None  #: :meta private:
+    db: "ndb.NeuralDB" = None  #: :meta private:
     """NeuralDB instance"""

     model_config = ConfigDict(

@@ -36,7 +40,7 @@ class NeuralDBVectorStore(VectorStore):
     )

     @staticmethod
-    def _verify_thirdai_library(thirdai_key: Optional[str] = None):  # type: ignore[no-untyped-def]
+    def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
         try:
             from thirdai import licensing

@@ -50,11 +54,11 @@ class NeuralDBVectorStore(VectorStore):
         )

     @classmethod
-    def from_scratch(  # type: ignore[no-untyped-def, no-untyped-def]
+    def from_scratch(
         cls,
         thirdai_key: Optional[str] = None,
-        **model_kwargs,
-    ):
+        **model_kwargs: Any,
+    ) -> Self:
         """
         Create a NeuralDBVectorStore from scratch.

@@ -84,11 +88,11 @@ class NeuralDBVectorStore(VectorStore):
         return cls(db=ndb.NeuralDB(**model_kwargs))  # type: ignore[call-arg]

     @classmethod
-    def from_checkpoint(  # type: ignore[no-untyped-def]
+    def from_checkpoint(
         cls,
         checkpoint: Union[str, Path],
         thirdai_key: Optional[str] = None,
-    ):
+    ) -> Self:
         """
         Create a NeuralDBVectorStore with a base model from a saved checkpoint

@@ -163,13 +167,13 @@ class NeuralDBVectorStore(VectorStore):
         offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
         return [str(offset + i) for i in range(len(texts))]  # type: ignore[arg-type]

-    def insert(  # type: ignore[no-untyped-def, no-untyped-def]
+    def insert(
         self,
-        sources: List[Any],
+        sources: list[Union[str, "ndb.Document"]],
         train: bool = True,
         fast_mode: bool = True,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> list[str]:
         """Inserts files / document sources into the vectorstore.

         Args:

@@ -180,14 +184,16 @@ class NeuralDBVectorStore(VectorStore):
                 Defaults to True.
         """
         sources = self._preprocess_sources(sources)
-        self.db.insert(
+        return self.db.insert(
             sources=sources,
             train=train,
             fast_approximation=fast_mode,
             **kwargs,
         )

-    def _preprocess_sources(self, sources):  # type: ignore[no-untyped-def]
+    def _preprocess_sources(
+        self, sources: list[Union[str, "ndb.Document"]]
+    ) -> list["ndb.Document"]:
         """Checks if the provided sources are string paths. If they are, convert
         to NeuralDB document objects.

@@ -219,7 +225,7 @@ class NeuralDBVectorStore(VectorStore):
             )
         return preprocessed_sources

-    def upvote(self, query: str, document_id: Union[int, str]):  # type: ignore[no-untyped-def]
+    def upvote(self, query: str, document_id: Union[int, str]) -> None:
         """The vectorstore upweights the score of a document for a specific query.
         This is useful for fine-tuning the vectorstore to user behavior.

@@ -229,7 +235,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.text_to_result(query, int(document_id))

-    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):  # type: ignore[no-untyped-def]
+    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
         """Given a batch of (query, document id) pairs, the vectorstore upweights
         the scores of the document for the corresponding queries.
         This is useful for fine-tuning the vectorstore to user behavior.

@@ -242,7 +248,7 @@ class NeuralDBVectorStore(VectorStore):
             [(query, int(doc_id)) for query, doc_id in query_id_pairs]
         )

-    def associate(self, source: str, target: str):  # type: ignore[no-untyped-def]
+    def associate(self, source: str, target: str) -> None:
         """The vectorstore associates a source phrase with a target phrase.
         When the vectorstore sees the source phrase, it will also consider results
         that are relevant to the target phrase.

@@ -253,7 +259,7 @@ class NeuralDBVectorStore(VectorStore):
         """
         self.db.associate(source, target)

-    def associate_batch(self, text_pairs: List[Tuple[str, str]]):  # type: ignore[no-untyped-def]
+    def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
         """Given a batch of (source, target) pairs, the vectorstore associates
         each source phrase with the corresponding target phrase.

@@ -291,7 +297,7 @@ class NeuralDBVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def save(self, path: str):  # type: ignore[no-untyped-def]
+    def save(self, path: str) -> None:
         """Saves a NeuralDB instance to disk. Can be loaded into memory by
         calling NeuralDB.from_checkpoint(path)

@@ -325,10 +331,10 @@ class NeuralDBClientVectorStore(VectorStore):

     """

-    def __init__(self, db: Any) -> None:
+    def __init__(self, db: "ndb.NeuralDBClient") -> None:
         self.db = db

-    db: Any = None  #: :meta private:
+    db: "ndb.NeuralDBClient" = None  #: :meta private:
     """NeuralDB Client instance"""

     model_config = ConfigDict(

@@ -362,7 +368,7 @@ class NeuralDBClientVectorStore(VectorStore):
         except Exception as e:
             raise ValueError(f"Error while retrieving documents: {e}") from e

-    def insert(self, documents: List[Dict[str, Any]]):  # type: ignore[no-untyped-def, no-untyped-def]
+    def insert(self, documents: List[Dict[str, Any]]) -> Any:
         """
         Inserts documents into the VectorStore and return the corresponding Sources.

@@ -446,7 +452,7 @@ class NeuralDBClientVectorStore(VectorStore):
         """
         return self.db.insert(documents)

-    def remove_documents(self, source_ids: List[str]):  # type: ignore[no-untyped-def]
+    def remove_documents(self, source_ids: list[str]) -> None:
         """
         Deletes documents from the VectorStore using source ids.
@@ -8,6 +8,7 @@ import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
+from typing_extensions import Self

 from langchain_community.vectorstores.utils import maximal_marginal_relevance

@@ -31,7 +32,14 @@ class VikingDBConfig(object):
         scheme(str):http or https, defaulting to http.
     """

-    def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):  # type: ignore[no-untyped-def]
+    def __init__(
+        self,
+        host: str = "host",
+        region: str = "region",
+        ak: str = "ak",
+        sk: str = "sk",
+        scheme: str = "http",
+    ) -> None:
         self.host = host
         self.region = region
         self.ak = ak

@@ -397,7 +405,7 @@ class VikingDB(VectorStore):
         self.collection.delete_data(ids)  # type: ignore[union-attr]

     @classmethod
-    def from_texts(  # type: ignore[no-untyped-def, override]
+    def from_texts(  # type: ignore[override]
         cls,
         texts: List[str],
         embedding: Embeddings,

@@ -407,7 +415,7 @@ class VikingDB(VectorStore):
         index_params: Optional[dict] = None,
         drop_old: bool = False,
         **kwargs: Any,
-    ):
+    ) -> Self:
         """Create a collection, indexes it and insert data."""
         if connection_args is None:
             raise Exception("VikingDBConfig does not exists")

@@ -57,7 +57,7 @@ def test_add_messages() -> None:
     assert len(message_store_another.messages) == 0


-def test_tidb_recent_chat_message():  # type: ignore[no-untyped-def]
+def test_tidb_recent_chat_message() -> None:
     """Test the TiDBChatMessageHistory with earliest_time parameter."""
     import time
     from datetime import datetime

@@ -5,13 +5,16 @@ You can get a list of models from the bedrock client by running 'bedrock_models(
 """

 import os
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import pytest
 from langchain_core.callbacks import AsyncCallbackHandler

 from langchain_community.llms.bedrock import Bedrock

+if TYPE_CHECKING:
+    from botocore.client import BaseClient
+
 # this is the guardrails id for the model you want to test
 GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
 # this is the guardrails version for the model you want to test

@@ -37,12 +40,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
         if reason == "GUARDRAIL_INTERVENED":
             self.guardrails_intervened = True

-    def get_response(self):  # type: ignore[no-untyped-def]
+    def get_response(self) -> bool:
         return self.guardrails_intervened


 @pytest.fixture(autouse=True)
-def bedrock_runtime_client():  # type: ignore[no-untyped-def]
+def bedrock_runtime_client() -> "BaseClient":
     import boto3

     try:

@@ -56,7 +59,7 @@ def bedrock_runtime_client():  # type: ignore[no-untyped-def]


 @pytest.fixture(autouse=True)
-def bedrock_client():  # type: ignore[no-untyped-def]
+def bedrock_client() -> "BaseClient":
     import boto3

     try:

@@ -70,7 +73,7 @@ def bedrock_client():  # type: ignore[no-untyped-def]


 @pytest.fixture
-def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
+def bedrock_models(bedrock_client: "BaseClient") -> dict:
     """List bedrock models."""
     response = bedrock_client.list_foundation_models().get("modelSummaries")
     models = {}

@@ -79,7 +82,9 @@ def bedrock_models(bedrock_client):  # type: ignore[no-untyped-def]
     return models


-def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
+def test_claude_instant_v1(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",

@@ -92,9 +97,9 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):  # type: ignore[no-untyped-def]
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
-    bedrock_runtime_client, bedrock_models
-):
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         llm = Bedrock(
             model_id="anthropic.claude-instant-v1",

@@ -112,9 +117,9 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(  # type: ignore[no-untyped-def]
         pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)


-def test_amazon_bedrock_guardrails_intervention_for_invalid_query(  # type: ignore[no-untyped-def]
-    bedrock_runtime_client, bedrock_models
-):
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client: "BaseClient", bedrock_models: dict
+) -> None:
     try:
         handler = BedrockAsyncCallbackHandler()
         llm = Bedrock(
@@ -1,13 +1,13 @@
 """Unit test for Google Trends API Wrapper."""

 import os
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper


 @patch("serpapi.SerpApiClient.get_json")
-def test_unexpected_response(mocked_serpapiclient):  # type: ignore[no-untyped-def]
+def test_unexpected_response(mocked_serpapiclient: MagicMock) -> None:
     os.environ["SERPAPI_API_KEY"] = "123abcd"
     resp = {
         "search_metadata": {

@@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
     return True


-def assert_documents_equals(actual: List[Document], expected: List[Document]):  # type: ignore[no-untyped-def]
+def assert_documents_equals(actual: List[Document], expected: List[Document]) -> None:
     assert len(actual) == len(expected)

     for actual_doc, expected_doc in zip(actual, expected):

@@ -1,5 +1,7 @@
 """Test Deep Lake functionality."""

+from collections.abc import Iterator
+
 import pytest
 from langchain_core.documents import Document
 from pytest import FixtureRequest

@@ -9,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings


 @pytest.fixture
-def deeplake_datastore() -> DeepLake:  # type: ignore[misc]
+def deeplake_datastore() -> Iterator[DeepLake]:
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": str(i)} for i in range(len(texts))]
     docsearch = DeepLake.from_texts(

@@ -53,7 +55,7 @@ def test_deeplake_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": "0"})]


-def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_deeplake_with_persistence(deeplake_datastore: DeepLake) -> None:
     """Test end to end construction and search, with persistence."""
     output = deeplake_datastore.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": "0"})]

@@ -73,7 +75,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     # Or on program exit


-def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_deeplake_overwrite_flag(deeplake_datastore: DeepLake) -> None:
     """Test overwrite behavior"""
     dataset_path = deeplake_datastore.vectorstore.dataset_handler.path

@@ -109,7 +111,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
     output = docsearch.similarity_search("foo", k=1)


-def test_similarity_search(deeplake_datastore) -> None:  # type: ignore[no-untyped-def]
+def test_similarity_search(deeplake_datastore: DeepLake) -> None:
     """Test similarity search."""
     distance_metric = "cos"
     output = deeplake_datastore.similarity_search(

@@ -2,6 +2,7 @@

 import os
 import random
+from types import ModuleType
 from typing import Any, Dict, List

 import numpy as np

@@ -29,7 +30,6 @@ TYPE_4B_FILTERING_TEST_CASES = [
     ),
 ]

-
 try:
     from hdbcli import dbapi

@@ -56,15 +56,15 @@ embedding = NormalizedFakeEmbeddings()


 class ConfigData:
-    def __init__(self):  # type: ignore[no-untyped-def]
-        self.conn = None
-        self.schema_name = ""
+    def __init__(self) -> None:
+        self.conn: dbapi.Connection = None
+        self.schema_name: str = ""


 test_setup = ConfigData()


-def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
+def generateSchemaName(cursor: "dbapi.Cursor") -> str:
     # return "Langchain"
     cursor.execute(
         "SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "

@@ -78,7 +78,7 @@ def generateSchemaName(cursor):  # type: ignore[no-untyped-def]
     return f"VEC_{uid}"


-def setup_module(module):  # type: ignore[no-untyped-def]
+def setup_module(module: ModuleType) -> None:
     test_setup.conn = dbapi.connect(
         address=os.environ.get("HANA_DB_ADDRESS"),
         port=os.environ.get("HANA_DB_PORT"),

@@ -101,7 +101,7 @@ def setup_module(module):  # type: ignore[no-untyped-def]
     cur.close()


-def teardown_module(module):  # type: ignore[no-untyped-def]
+def teardown_module(module: ModuleType) -> None:
     # return
     try:
         cur = test_setup.conn.cursor()

@@ -119,17 +119,17 @@ def texts() -> List[str]:


 @pytest.fixture
-def metadatas() -> List[str]:
+def metadatas() -> List[dict[str, Any]]:
     return [
-        {"start": 0, "end": 100, "quality": "good", "ready": True},  # type: ignore[list-item]
-        {"start": 100, "end": 200, "quality": "bad", "ready": False},  # type: ignore[list-item]
-        {"start": 200, "end": 300, "quality": "ugly", "ready": True},  # type: ignore[list-item]
-        {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},  # type: ignore[list-item]
-        {"start": 300, "quality": "ugly", "Owner": "Steve"},  # type: ignore[list-item]
+        {"start": 0, "end": 100, "quality": "good", "ready": True},
+        {"start": 100, "end": 200, "quality": "bad", "ready": False},
+        {"start": 200, "end": 300, "quality": "ugly", "ready": True},
+        {"start": 200, "quality": "ugly", "ready": True, "Owner": "Steve"},
+        {"start": 300, "quality": "ugly", "Owner": "Steve"},
     ]


-def drop_table(connection, table_name):  # type: ignore[no-untyped-def]
+def drop_table(connection: "dbapi.Connection", table_name: str) -> None:
     try:
         cur = connection.cursor()
         sql_str = f"DROP TABLE {table_name}"

@@ -279,7 +279,7 @@ def test_hanavector_non_existing_table_fixed_vector_length() -> None:

     assert vectordb._table_exists(table_name)
     vectordb._check_column(
-        table_name, vector_column, "REAL_VECTOR", vector_column_length
+        table_name, vector_column, ["REAL_VECTOR"], vector_column_length
     )
@ -32,7 +32,7 @@ def fix_distance_precision(
|
||||
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def __init__(self): # type: ignore[no-untyped-def]
|
||||
def __init__(self) -> None:
|
||||
super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
|
@ -1,13 +1,15 @@
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.vectorstores import NeuralDBVectorStore
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_csv(): # type: ignore[no-untyped-def]
|
||||
def test_csv() -> Iterator[str]:
|
||||
csv = "thirdai-test.csv"
|
||||
with open(csv, "w") as o:
|
||||
o.write("column_1,column_2\n")
|
||||
@ -16,13 +18,13 @@ def test_csv(): # type: ignore[no-untyped-def]
|
||||
os.remove(csv)
|
||||
|
||||
|
||||
def assert_result_correctness(documents): # type: ignore[no-untyped-def]
|
||||
def assert_result_correctness(documents: list[Document]) -> None:
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-def]
|
||||
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
retriever.insert([test_csv])
|
||||
documents = retriever.similarity_search("column")
|
||||
@ -30,7 +32,7 @@ def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-d
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untyped-def]
|
||||
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
|
||||
checkpoint = "thirdai-test-save.ndb"
|
||||
if os.path.exists(checkpoint):
|
||||
shutil.rmtree(checkpoint)
|
||||
@ -47,7 +49,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untype
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
|
||||
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
retriever.insert([test_csv])
|
||||
# Make sure they don't throw an error.
|
||||
|
@ -268,7 +268,7 @@ def vectara3() -> Iterable[Vectara]:
|
||||
vectara3.delete(doc_ids)
|
||||
|
||||
|
||||
def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[no-untyped-def]
|
||||
def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None:
|
||||
# test max marginal relevance
|
||||
output1 = vectara3.max_marginal_relevance_search(
|
||||
"generative AI",
|
||||
@ -299,7 +299,7 @@ def test_vectara_with_langchain_mmr(vectara3: Vectara) -> None: # type: ignore[
|
||||
)
|
||||
|
||||
|
||||
def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untyped-def]
|
||||
def test_vectara_rerankers(vectara3: Vectara) -> None:
|
||||
# test Vectara multi-lingual reranker
|
||||
summary_config = SummaryConfig(is_enabled=True, max_results=7, response_lang="eng")
|
||||
rerank_config = RerankConfig(reranker="rerank_multilingual_v1", rerank_k=50)
|
||||
@ -375,7 +375,7 @@ def test_vectara_rerankers(vectara3: Vectara) -> None: # type: ignore[no-untype
|
||||
assert len(output2) > 0
|
||||
|
||||
|
||||
def test_vectara_with_summary(vectara3) -> None: # type: ignore[no-untyped-def]
|
||||
def test_vectara_with_summary(vectara3: Vectara) -> None:
|
||||
"""Test vectara summary."""
|
||||
# test summarization
|
||||
num_results = 10
|
||||
|
@ -43,7 +43,7 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
|
||||
("role", "role_response"),
|
||||
[("ai", "assistant"), ("human", "user"), ("chat", "user")],
|
||||
)
|
||||
def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def]
|
||||
def test_edenai_message_role(role: str, role_response: str) -> None:
|
||||
role = _message_role(role)
|
||||
assert role == role_response
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Generator, cast
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -20,6 +20,9 @@ from langchain_community.chat_models.naver import (
|
||||
_convert_naver_chat_message_to_message,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx_sse import ServerSentEvent
|
||||
|
||||
os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key"
|
||||
os.environ["NCP_APIGW_API_KEY"] = "test_gw_key"
|
||||
|
||||
@ -131,7 +134,7 @@ async def test_naver_ainvoke(mock_chat_completion_response: dict) -> None:
|
||||
assert completed
|
||||
|
||||
|
||||
def _make_completion_response_from_token(token: str): # type: ignore[no-untyped-def]
|
||||
def _make_completion_response_from_token(token: str) -> "ServerSentEvent":
|
||||
from httpx_sse import ServerSentEvent
|
||||
|
||||
return ServerSentEvent(
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test OCI Generative AI LLM service"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -10,7 +11,7 @@ from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
|
||||
|
||||
|
||||
class MockResponseDict(dict):
|
||||
def __getattr__(self, val): # type: ignore[no-untyped-def]
|
||||
def __getattr__(self, val: str) -> Any:
|
||||
return self[val]
|
||||
|
||||
|
||||
@ -29,7 +30,7 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
|
||||
|
||||
provider = model_id.split(".")[0].lower()
|
||||
|
||||
def mocked_response(*args): # type: ignore[no-untyped-def]
|
||||
def mocked_response(*args: Any) -> MockResponseDict:
|
||||
response_text = "Assistant chat reply."
|
||||
response = None
|
||||
if provider == "cohere":
|
||||
@ -102,6 +103,8 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider {provider}")
|
||||
return response
|
||||
|
||||
monkeypatch.setattr(llm.client, "chat", mocked_response)
|
||||
|
@ -20,7 +20,7 @@ class MockEmbed4All(MagicMock):
|
||||
n_threads: Optional[int] = None,
|
||||
device: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
): # type: ignore[no-untyped-def]
|
||||
):
|
||||
assert model_name == _GPT4ALL_MODEL_NAME
|
||||
|
||||
|
||||
|
@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
|
||||
output.embeddings = embeddings
|
||||
return output
|
||||
|
||||
async def aembed(self, *args) -> Any: # type: ignore[no-untyped-def]
|
||||
async def aembed(self, *args: Any) -> Any:
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class MockGradient(MagicMock):
|
||||
"""Mock Gradient package."""
|
||||
|
||||
def __init__(self, access_token: str, workspace_id, host): # type: ignore[no-untyped-def]
|
||||
def __init__(self, access_token: str, workspace_id: str, host: str) -> None:
|
||||
assert access_token == _GRADIENT_SECRET
|
||||
assert workspace_id == _GRADIENT_WORKSPACE_ID
|
||||
assert host == _GRADIENT_BASE_URL
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@ -23,7 +24,9 @@ def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
|
||||
base_url="http://llamafile-host:8080",
|
||||
)
|
||||
|
||||
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
|
||||
def mock_post(
|
||||
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
|
||||
) -> requests.Response:
|
||||
assert url == "http://llamafile-host:8080/embedding"
|
||||
assert headers == {
|
||||
"Content-Type": "application/json",
|
||||
@ -50,7 +53,9 @@ def test_embed_query(monkeypatch: MonkeyPatch) -> None:
|
||||
base_url="http://llamafile-host:8080",
|
||||
)
|
||||
|
||||
def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
|
||||
def mock_post(
|
||||
url: str, headers: dict[str, str], json: dict[str, Any], timeout: float
|
||||
) -> None:
|
||||
assert url == "http://llamafile-host:8080/embedding"
|
||||
assert headers == {
|
||||
"Content-Type": "application/json",
|
||||
|

@ -1,5 +1,6 @@
"""Test OCI Generative AI embedding service."""

from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock

import pytest
@ -7,9 +8,12 @@ from pytest import MonkeyPatch

from langchain_community.embeddings import OCIGenAIEmbeddings

if TYPE_CHECKING:
from oci.generative_ai_inference.models import EmbedTextDetails


class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
def __getattr__(self, val: str) -> Any:
return self[val]


@ -26,7 +30,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
client=oci_gen_ai_client,
)

def mocked_response(invocation_obj): # type: ignore[no-untyped-def]
def mocked_response(invocation_obj: "EmbedTextDetails") -> MockResponseDict:
docs = invocation_obj.inputs

embeddings = []

@ -1,14 +1,14 @@
from langchain_community.graphs.neo4j_graph import value_sanitize


def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_small_list() -> None:
small_list = list(range(15)) # list size < LIST_LIMIT
input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = {
@ -18,21 +18,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_nested_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_nested_list() -> None:
input_dict = {
"key1": "value1",
"deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]],
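
Taken together, these tests pin down value_sanitize's contract: recurse through dicts and lists, keep small lists, and drop any list that reaches the limit, wherever it is nested. A rough reimplementation of that contract, assuming a cutoff constant (the 15-versus-150 fixtures are consistent with a LIST_LIMIT of 128, but treat the exact value as an assumption):

from typing import Any, Optional

LIST_LIMIT = 128  # assumed cutoff, consistent with the fixtures above

def sanitize(value: Any) -> Optional[Any]:
    """Drop oversized lists anywhere in a nested structure (sketch)."""
    if isinstance(value, dict):
        out = {}
        for key, val in value.items():
            cleaned = sanitize(val)
            if cleaned is not None:
                out[key] = cleaned
        return out
    if isinstance(value, list):
        if len(value) >= LIST_LIMIT:
            return None  # signals the caller to drop this key or element
        return [c for c in (sanitize(v) for v in value) if c is not None]
    return value

assert sanitize({"a": list(range(150)), "b": [1, 2]}) == {"b": [1, 2]}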

@ -1,6 +1,6 @@
import json
from collections import deque
from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
import requests
@ -39,7 +39,7 @@ def mock_response() -> requests.Response:
return response


def mock_response_stream(): # type: ignore[no-untyped-def]
def mock_response_stream() -> requests.Response:
mock_response = deque(
[
b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n',
@ -48,7 +48,7 @@ def mock_response_stream(): # type: ignore[no-untyped-def]
)

class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def]
def read(self, chunk_size: int) -> Optional[bytes]:
try:
return mock_response.popleft()
except IndexError:
@ -68,7 +68,13 @@ def test_call(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion"
assert headers == {
"Content-Type": "application/json",
@ -94,7 +100,13 @@ def test_call_with_kwargs(monkeypatch: MonkeyPatch) -> None:
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-host:8080/completion"
assert headers == {
"Content-Type": "application/json",
@ -138,7 +150,13 @@ def test_streaming(monkeypatch: MonkeyPatch) -> None:
streaming=True,
)

def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
) -> requests.Response:
assert url == "http://llamafile-hostname:8080/completion"
assert headers == {
"Content-Type": "application/json",
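
The deque-backed MockRaw above hand-rolls a streaming body chunk by chunk. An equivalent, slightly simpler stub (my sketch, not the test's code) relies on the fact that requests falls back to raw.read() when the raw object has no stream() method, so an io.BytesIO can serve as the body:

import io

import requests

def fake_streaming_response(chunks: list[bytes]) -> requests.Response:
    # BytesIO exposes read() but not stream(), so Response.iter_content()
    # pulls the body through chunked read() calls on this buffer.
    resp = requests.Response()
    resp.status_code = 200
    resp.raw = io.BytesIO(b"".join(chunks))
    return resp

resp = fake_streaming_response(
    [b'data: {"content":"the","stop":false}\n\n']
)
assert b"".join(resp.iter_content(chunk_size=8)).startswith(b"data:")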

@ -1,5 +1,6 @@
"""Test OCI Generative AI LLM service"""

from typing import Any
from unittest.mock import MagicMock

import pytest
@ -9,7 +10,7 @@ from langchain_community.llms.oci_generative_ai import OCIGenAI


class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
def __getattr__(self, val: Any) -> Any:
return self[val]


@ -29,7 +30,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:

provider = model_id.split(".")[0].lower()

def mocked_response(*args): # type: ignore[no-untyped-def]
def mocked_response(*args: Any) -> MockResponseDict:
response_text = "This is the completion."

if provider == "cohere":
@ -75,6 +76,7 @@ def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
),
}
)
raise ValueError("Unsupported provider")

monkeypatch.setattr(llm.client, "generate_text", mocked_response)
output = llm.invoke("This is a prompt.", temperature=0.2)

@ -1,14 +1,16 @@
from typing import Any, Optional

import requests
from pytest import MonkeyPatch

from langchain_community.llms.ollama import Ollama


def mock_response_stream(): # type: ignore[no-untyped-def]
def mock_response_stream() -> requests.Response:
mock_response = [b'{ "response": "Response chunk 1" }']

class MockRaw:
def read(self, chunk_size): # type: ignore[no-untyped-def]
def read(self, chunk_size: int) -> Optional[bytes]:
try:
return mock_response.pop()
except IndexError:
@ -31,7 +33,14 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -57,7 +66,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -77,7 +93,14 @@ def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -97,7 +120,14 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -145,7 +175,14 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -194,7 +231,14 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
def mock_post(
url: str,
headers: dict[str, str],
json: dict[str, Any],
stream: bool,
timeout: Optional[float],
auth: tuple[str, str],
) -> requests.Response:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",

@ -1,5 +1,6 @@
"""Test building the Zapier tool, not running it."""

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest
@ -9,6 +10,9 @@ from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain_community.tools.zapier.tool import ZapierNLARunAction
from langchain_community.utilities.zapier import ZapierNLAWrapper

if TYPE_CHECKING:
from pytest_mock import MockerFixture


def test_default_base_prompt() -> None:
"""Test that the default prompt is being inserted."""
@ -136,7 +140,7 @@ def test_create_action_payload_with_params() -> None:
assert payload["test"] == "test"


async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def]
async def test_apreview(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",
@ -164,7 +168,7 @@ async def test_apreview(mocker) -> None: # type: ignore[no-untyped-def]
)


async def test_arun(mocker) -> None: # type: ignore[no-untyped-def]
async def test_arun(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",
@ -188,7 +192,7 @@ async def test_arun(mocker) -> None: # type: ignore[no-untyped-def]
)


async def test_alist(mocker) -> None: # type: ignore[no-untyped-def]
async def test_alist(mocker: "MockerFixture") -> None:
"""Test that the action payload with params is being created correctly."""
tool = ZapierNLARunAction(
action_id="test",
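
The TYPE_CHECKING import plus string annotation used for mocker here is the standard way to type pytest-mock's fixture without making pytest_mock a runtime import. A small runnable sketch of the same arrangement (the test body patches json.loads purely for illustration):

import json
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # resolved by mypy only; no import cost at runtime
    from pytest_mock import MockerFixture

def test_parse_uses_loads(mocker: "MockerFixture") -> None:
    patched = mocker.patch("json.loads", return_value={"ok": True})
    assert json.loads("{}") == {"ok": True}
    patched.assert_called_once_with("{}")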

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from unittest.mock import patch

import pytest
@ -9,6 +9,9 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

DEFAULT_VECTOR_DIMENSION = 4

if TYPE_CHECKING:
from azure.search.documents.indexes.models import SearchIndex


class FakeEmbeddingsWithDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
@ -36,7 +39,7 @@ DEFAULT_ACCESS_TOKEN = "myaccesstoken1"
DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension()


def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def]
def mock_default_index(*args: Any, **kwargs: Any) -> "SearchIndex":
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
@ -155,12 +158,12 @@ def test_init_new_index() -> None:
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import SearchIndex

def no_index(self, name: str): # type: ignore[no-untyped-def]
def no_index(self: SearchIndexClient, name: str) -> SearchIndex:
raise ResourceNotFoundError

created_index: Optional[SearchIndex] = None

def mock_create_index(self, index): # type: ignore[no-untyped-def]
def mock_create_index(self: SearchIndexClient, index: SearchIndex) -> None:
nonlocal created_index
created_index = index

@ -203,7 +206,9 @@ def test_ids_used_correctly() -> None:
def __init__(self) -> None:
self.succeeded: bool = True

def mock_upload_documents(self, documents: List[object]) -> List[Response]: # type: ignore[no-untyped-def]
def mock_upload_documents(
self: SearchClient, documents: List[object]
) -> List[Response]:
# assume all documents uploaded successfully
response = [Response() for _ in documents]
return response
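
One detail worth calling out in the Azure hunks: because no_index, mock_create_index, and mock_upload_documents are patched onto the client classes rather than onto instances, the explicit self parameter needs an annotation too. A self-contained sketch of that pattern with a stand-in client class (Client and its upload method are hypothetical, not from the diff):

from pytest import MonkeyPatch

class Client:
    # Stand-in for an SDK client whose method the test replaces.
    def upload(self, docs: list[object]) -> list[str]:
        raise RuntimeError("real network call")

def mock_upload(self: Client, docs: list[object]) -> list[str]:
    # Annotating the explicit `self` keeps mypy satisfied when the
    # function is attached to the class instead of an instance.
    return ["ok" for _ in docs]

def test_upload(monkeypatch: MonkeyPatch) -> None:
    monkeypatch.setattr(Client, "upload", mock_upload)
    assert Client().upload([object()]) == ["ok"]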