feat(qdrant): ruff fixes and rules (#32500)

This commit is contained in:
Mason Daugherty 2025-08-11 12:43:41 -04:00 committed by GitHub
parent 9b3f3dc8d9
commit 374f414c91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1659 additions and 1489 deletions

View File

@ -4,10 +4,10 @@ from langchain_qdrant.sparse_embeddings import SparseEmbeddings, SparseVector
from langchain_qdrant.vectorstores import Qdrant from langchain_qdrant.vectorstores import Qdrant
__all__ = [ __all__ = [
"FastEmbedSparse",
"Qdrant", "Qdrant",
"QdrantVectorStore", "QdrantVectorStore",
"RetrievalMode",
"SparseEmbeddings", "SparseEmbeddings",
"SparseVector", "SparseVector",
"FastEmbedSparse",
"RetrievalMode",
] ]

View File

@ -47,17 +47,17 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
X = np.array(X) X = np.array(X)
Y = np.array(Y) Y = np.array(Y)
if X.shape[1] != Y.shape[1]: if X.shape[1] != Y.shape[1]:
raise ValueError( msg = (
f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}." f"and Y has shape {Y.shape}."
) )
raise ValueError(msg)
try: try:
import simsimd as simd import simsimd as simd
X = np.array(X, dtype=np.float32) X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32) Y = np.array(Y, dtype=np.float32)
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) return 1 - np.array(simd.cdist(X, Y, metric="cosine"))
return Z
except ImportError: except ImportError:
X_norm = np.linalg.norm(X, axis=1) X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1) Y_norm = np.linalg.norm(Y, axis=1)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
@ -17,9 +19,11 @@ class FastEmbedSparse(SparseEmbeddings):
parallel: Optional[int] = None, parallel: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """Sparse encoder implementation using FastEmbed.
Sparse encoder implementation using FastEmbed - https://qdrant.github.io/fastembed/
For a list of available models, see https://qdrant.github.io/fastembed/examples/Supported_Models/ Uses `FastEmbed <https://qdrant.github.io/fastembed/>`__ for sparse text
embeddings.
For a list of available models, see `the Qdrant docs <https://qdrant.github.io/fastembed/examples/Supported_Models/>`__.
Args: Args:
model_name (str): The name of the model to use. Defaults to `"Qdrant/bm25"`. model_name (str): The name of the model to use. Defaults to `"Qdrant/bm25"`.
@ -38,15 +42,19 @@ class FastEmbedSparse(SparseEmbeddings):
kwargs: Additional options to pass to fastembed.SparseTextEmbedding kwargs: Additional options to pass to fastembed.SparseTextEmbedding
Raises: Raises:
ValueError: If the model_name is not supported in SparseTextEmbedding. ValueError: If the model_name is not supported in SparseTextEmbedding.
""" """
try: try:
from fastembed import SparseTextEmbedding # type: ignore from fastembed import ( # type: ignore[import-not-found]
except ImportError: SparseTextEmbedding,
raise ValueError( )
except ImportError as err:
msg = (
"The 'fastembed' package is not installed. " "The 'fastembed' package is not installed. "
"Please install it with " "Please install it with "
"`pip install fastembed` or `pip install fastembed-gpu`." "`pip install fastembed` or `pip install fastembed-gpu`."
) )
raise ValueError(msg) from err
self._batch_size = batch_size self._batch_size = batch_size
self._parallel = parallel self._parallel = parallel
self._model = SparseTextEmbedding( self._model = SparseTextEmbedding(

View File

@ -185,8 +185,8 @@ class QdrantVectorStore(VectorStore):
distance: models.Distance = models.Distance.COSINE, distance: models.Distance = models.Distance.COSINE,
sparse_embedding: Optional[SparseEmbeddings] = None, sparse_embedding: Optional[SparseEmbeddings] = None,
sparse_vector_name: str = SPARSE_VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME,
validate_embeddings: bool = True, validate_embeddings: bool = True, # noqa: FBT001, FBT002
validate_collection_config: bool = True, validate_collection_config: bool = True, # noqa: FBT001, FBT002
): ):
"""Initialize a new instance of `QdrantVectorStore`. """Initialize a new instance of `QdrantVectorStore`.
@ -232,6 +232,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
QdrantClient: An instance of ``QdrantClient``. QdrantClient: An instance of ``QdrantClient``.
""" """
return self._client return self._client
@ -244,11 +245,11 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
Embeddings: An instance of ``Embeddings``. Embeddings: An instance of ``Embeddings``.
""" """
if self._embeddings is None: if self._embeddings is None:
raise ValueError( msg = "Embeddings are `None`. Please set using the `embedding` parameter."
"Embeddings are `None`. Please set using the `embedding` parameter." raise ValueError(msg)
)
return self._embeddings return self._embeddings
@property @property
@ -260,12 +261,14 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
SparseEmbeddings: An instance of ``SparseEmbeddings``. SparseEmbeddings: An instance of ``SparseEmbeddings``.
""" """
if self._sparse_embeddings is None: if self._sparse_embeddings is None:
raise ValueError( msg = (
"Sparse embeddings are `None`. " "Sparse embeddings are `None`. "
"Please set using the `sparse_embedding` parameter." "Please set using the `sparse_embedding` parameter."
) )
raise ValueError(msg)
return self._sparse_embeddings return self._sparse_embeddings
@classmethod @classmethod
@ -280,8 +283,8 @@ class QdrantVectorStore(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -294,13 +297,13 @@ class QdrantVectorStore(VectorStore):
retrieval_mode: RetrievalMode = RetrievalMode.DENSE, retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
sparse_embedding: Optional[SparseEmbeddings] = None, sparse_embedding: Optional[SparseEmbeddings] = None,
sparse_vector_name: str = SPARSE_VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME,
collection_create_options: dict[str, Any] = {}, collection_create_options: Optional[dict[str, Any]] = None,
vector_params: dict[str, Any] = {}, vector_params: Optional[dict[str, Any]] = None,
sparse_vector_params: dict[str, Any] = {}, sparse_vector_params: Optional[dict[str, Any]] = None,
batch_size: int = 64, batch_size: int = 64,
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
validate_embeddings: bool = True, validate_embeddings: bool = True, # noqa: FBT001, FBT002
validate_collection_config: bool = True, validate_collection_config: bool = True, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> QdrantVectorStore: ) -> QdrantVectorStore:
"""Construct an instance of ``QdrantVectorStore`` from a list of texts. """Construct an instance of ``QdrantVectorStore`` from a list of texts.
@ -321,6 +324,12 @@ class QdrantVectorStore(VectorStore):
qdrant = Qdrant.from_texts(texts, embeddings, url="http://localhost:6333") qdrant = Qdrant.from_texts(texts, embeddings, url="http://localhost:6333")
""" """
if sparse_vector_params is None:
sparse_vector_params = {}
if vector_params is None:
vector_params = {}
if collection_create_options is None:
collection_create_options = {}
client_options = { client_options = {
"location": location, "location": location,
"url": url, "url": url,
@ -367,8 +376,8 @@ class QdrantVectorStore(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -380,8 +389,8 @@ class QdrantVectorStore(VectorStore):
vector_name: str = VECTOR_NAME, vector_name: str = VECTOR_NAME,
sparse_vector_name: str = SPARSE_VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME,
sparse_embedding: Optional[SparseEmbeddings] = None, sparse_embedding: Optional[SparseEmbeddings] = None,
validate_embeddings: bool = True, validate_embeddings: bool = True, # noqa: FBT001, FBT002
validate_collection_config: bool = True, validate_collection_config: bool = True, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> QdrantVectorStore: ) -> QdrantVectorStore:
"""Construct an instance of ``QdrantVectorStore`` from an existing collection """Construct an instance of ``QdrantVectorStore`` from an existing collection
@ -389,7 +398,8 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
QdrantVectorStore: A new instance of ``QdrantVectorStore``. QdrantVectorStore: A new instance of ``QdrantVectorStore``.
"""
""" # noqa: D205
client = QdrantClient( client = QdrantClient(
location=location, location=location,
url=url, url=url,
@ -420,7 +430,7 @@ class QdrantVectorStore(VectorStore):
validate_collection_config=validate_collection_config, validate_collection_config=validate_collection_config,
) )
def add_texts( # type: ignore def add_texts( # type: ignore[override]
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
@ -432,6 +442,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of ids from adding the texts into the vectorstore. List of ids from adding the texts into the vectorstore.
""" """
added_ids = [] added_ids = []
for batch_ids, points in self._generate_batches( for batch_ids, points in self._generate_batches(
@ -448,7 +459,7 @@ class QdrantVectorStore(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -460,6 +471,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = self.similarity_search_with_score( results = self.similarity_search_with_score(
query, query,
@ -478,7 +490,7 @@ class QdrantVectorStore(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -490,6 +502,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
""" """
query_options = { query_options = {
"collection_name": self.collection_name, "collection_name": self.collection_name,
@ -550,7 +563,8 @@ class QdrantVectorStore(VectorStore):
).points ).points
else: else:
raise ValueError(f"Invalid retrieval mode. {self.retrieval_mode}.") msg = f"Invalid retrieval mode. {self.retrieval_mode}."
raise ValueError(msg)
return [ return [
( (
self._document_from_point( self._document_from_point(
@ -568,7 +582,7 @@ class QdrantVectorStore(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -579,6 +593,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of Documents most similar to the query and distance for each. List of Documents most similar to the query and distance for each.
""" """
qdrant_filter = filter qdrant_filter = filter
@ -621,7 +636,7 @@ class QdrantVectorStore(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -632,6 +647,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = self.similarity_search_with_score_by_vector( results = self.similarity_search_with_score_by_vector(
embedding, embedding,
@ -651,7 +667,7 @@ class QdrantVectorStore(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
@ -664,6 +680,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
self._validate_collection_for_dense( self._validate_collection_for_dense(
self.client, self.client,
@ -692,7 +709,7 @@ class QdrantVectorStore(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
@ -705,6 +722,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
results = self.max_marginal_relevance_search_with_score_by_vector( results = self.max_marginal_relevance_search_with_score_by_vector(
embedding, embedding,
@ -725,19 +743,21 @@ class QdrantVectorStore(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[models.Filter] = None, filter: Optional[models.Filter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Returns: Returns:
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
""" """
results = self.client.query_points( results = self.client.query_points(
collection_name=self.collection_name, collection_name=self.collection_name,
@ -756,7 +776,7 @@ class QdrantVectorStore(VectorStore):
embeddings = [ embeddings = [
result.vector result.vector
if isinstance(result.vector, list) if isinstance(result.vector, list)
else result.vector.get(self.vector_name) # type: ignore else result.vector.get(self.vector_name) # type: ignore[union-attr]
for result in results for result in results
] ]
mmr_selected = maximal_marginal_relevance( mmr_selected = maximal_marginal_relevance(
@ -775,7 +795,7 @@ class QdrantVectorStore(VectorStore):
for i in mmr_selected for i in mmr_selected
] ]
def delete( # type: ignore def delete( # type: ignore[override]
self, self,
ids: Optional[list[str | int]] = None, ids: Optional[list[str | int]] = None,
**kwargs: Any, **kwargs: Any,
@ -788,6 +808,7 @@ class QdrantVectorStore(VectorStore):
Returns: Returns:
True if deletion is successful, False otherwise. True if deletion is successful, False otherwise.
""" """
result = self.client.delete( result = self.client.delete(
collection_name=self.collection_name, collection_name=self.collection_name,
@ -814,20 +835,28 @@ class QdrantVectorStore(VectorStore):
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
retrieval_mode: RetrievalMode = RetrievalMode.DENSE, retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
sparse_embedding: Optional[SparseEmbeddings] = None, sparse_embedding: Optional[SparseEmbeddings] = None,
client_options: dict[str, Any] = {}, client_options: Optional[dict[str, Any]] = None,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
distance: models.Distance = models.Distance.COSINE, distance: models.Distance = models.Distance.COSINE,
content_payload_key: str = CONTENT_KEY, content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
vector_name: str = VECTOR_NAME, vector_name: str = VECTOR_NAME,
sparse_vector_name: str = SPARSE_VECTOR_NAME, sparse_vector_name: str = SPARSE_VECTOR_NAME,
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
collection_create_options: dict[str, Any] = {}, collection_create_options: Optional[dict[str, Any]] = None,
vector_params: dict[str, Any] = {}, vector_params: Optional[dict[str, Any]] = None,
sparse_vector_params: dict[str, Any] = {}, sparse_vector_params: Optional[dict[str, Any]] = None,
validate_embeddings: bool = True, validate_embeddings: bool = True, # noqa: FBT001, FBT002
validate_collection_config: bool = True, validate_collection_config: bool = True, # noqa: FBT001, FBT002
) -> QdrantVectorStore: ) -> QdrantVectorStore:
if sparse_vector_params is None:
sparse_vector_params = {}
if vector_params is None:
vector_params = {}
if collection_create_options is None:
collection_create_options = {}
if client_options is None:
client_options = {}
if validate_embeddings: if validate_embeddings:
cls._validate_embeddings(retrieval_mode, embedding, sparse_embedding) cls._validate_embeddings(retrieval_mode, embedding, sparse_embedding)
collection_name = collection_name or uuid.uuid4().hex collection_name = collection_name or uuid.uuid4().hex
@ -852,7 +881,7 @@ class QdrantVectorStore(VectorStore):
else: else:
vectors_config, sparse_vectors_config = {}, {} vectors_config, sparse_vectors_config = {}, {}
if retrieval_mode == RetrievalMode.DENSE: if retrieval_mode == RetrievalMode.DENSE:
partial_embeddings = embedding.embed_documents(["dummy_text"]) # type: ignore partial_embeddings = embedding.embed_documents(["dummy_text"]) # type: ignore[union-attr]
vector_params["size"] = len(partial_embeddings[0]) vector_params["size"] = len(partial_embeddings[0])
vector_params["distance"] = distance vector_params["distance"] = distance
@ -871,7 +900,7 @@ class QdrantVectorStore(VectorStore):
} }
elif retrieval_mode == RetrievalMode.HYBRID: elif retrieval_mode == RetrievalMode.HYBRID:
partial_embeddings = embedding.embed_documents(["dummy_text"]) # type: ignore partial_embeddings = embedding.embed_documents(["dummy_text"]) # type: ignore[union-attr]
vector_params["size"] = len(partial_embeddings[0]) vector_params["size"] = len(partial_embeddings[0])
vector_params["distance"] = distance vector_params["distance"] = distance
@ -894,7 +923,7 @@ class QdrantVectorStore(VectorStore):
client.create_collection(**collection_create_options) client.create_collection(**collection_create_options)
qdrant = cls( return cls(
client=client, client=client,
collection_name=collection_name, collection_name=collection_name,
embedding=embedding, embedding=embedding,
@ -908,7 +937,6 @@ class QdrantVectorStore(VectorStore):
validate_embeddings=False, validate_embeddings=False,
validate_collection_config=False, validate_collection_config=False,
) )
return qdrant
@staticmethod @staticmethod
def _cosine_relevance_score_fn(distance: float) -> float: def _cosine_relevance_score_fn(distance: float) -> float:
@ -916,25 +944,22 @@ class QdrantVectorStore(VectorStore):
return (distance + 1.0) / 2.0 return (distance + 1.0) / 2.0
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
""" """Your "correct" relevance function may differ depending on a few things.
The "correct" relevance function may differ depending on a few things,
including: Including:
- The distance / similarity metric used by the VectorStore - The distance / similarity metric used by the VectorStore
- The scale of your embeddings (OpenAI's are unit normed. Many others are not!) - The scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- Embedding dimensionality - Embedding dimensionality
- etc. - etc.
""" """
if self.distance == models.Distance.COSINE: if self.distance == models.Distance.COSINE:
return self._cosine_relevance_score_fn return self._cosine_relevance_score_fn
elif self.distance == models.Distance.DOT: if self.distance == models.Distance.DOT:
return self._max_inner_product_relevance_score_fn return self._max_inner_product_relevance_score_fn
elif self.distance == models.Distance.EUCLID: if self.distance == models.Distance.EUCLID:
return self._euclidean_relevance_score_fn return self._euclidean_relevance_score_fn
else: msg = "Unknown distance strategy, must be COSINE, DOT, or EUCLID."
raise ValueError( raise ValueError(msg)
"Unknown distance strategy, must be COSINE, DOT, or EUCLID."
)
@classmethod @classmethod
def _document_from_point( def _document_from_point(
@ -996,10 +1021,11 @@ class QdrantVectorStore(VectorStore):
payloads = [] payloads = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
if text is None: if text is None:
raise ValueError( msg = (
"At least one of the texts is None. Please remove it before " "At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts." "calling .from_texts or .add_texts."
) )
raise ValueError(msg)
metadata = metadatas[i] if metadatas is not None else None metadata = metadatas[i] if metadatas is not None else None
payloads.append( payloads.append(
{ {
@ -1023,7 +1049,7 @@ class QdrantVectorStore(VectorStore):
for vector in batch_embeddings for vector in batch_embeddings
] ]
elif self.retrieval_mode == RetrievalMode.SPARSE: if self.retrieval_mode == RetrievalMode.SPARSE:
batch_sparse_embeddings = self.sparse_embeddings.embed_documents( batch_sparse_embeddings = self.sparse_embeddings.embed_documents(
list(texts) list(texts)
) )
@ -1036,14 +1062,13 @@ class QdrantVectorStore(VectorStore):
for vector in batch_sparse_embeddings for vector in batch_sparse_embeddings
] ]
elif self.retrieval_mode == RetrievalMode.HYBRID: if self.retrieval_mode == RetrievalMode.HYBRID:
dense_embeddings = self.embeddings.embed_documents(list(texts)) dense_embeddings = self.embeddings.embed_documents(list(texts))
sparse_embeddings = self.sparse_embeddings.embed_documents(list(texts)) sparse_embeddings = self.sparse_embeddings.embed_documents(list(texts))
if len(dense_embeddings) != len(sparse_embeddings): if len(dense_embeddings) != len(sparse_embeddings):
raise ValueError( msg = "Mismatched length between dense and sparse embeddings."
"Mismatched length between dense and sparse embeddings." raise ValueError(msg)
)
return [ return [
{ {
@ -1057,10 +1082,8 @@ class QdrantVectorStore(VectorStore):
) )
] ]
else: msg = f"Unknown retrieval mode. {self.retrieval_mode} to build vectors."
raise ValueError( raise ValueError(msg)
f"Unknown retrieval mode. {self.retrieval_mode} to build vectors."
)
@classmethod @classmethod
def _validate_collection_config( def _validate_collection_config(
@ -1106,51 +1129,55 @@ class QdrantVectorStore(VectorStore):
if isinstance(vector_config, dict): if isinstance(vector_config, dict):
# vector_config is a Dict[str, VectorParams] # vector_config is a Dict[str, VectorParams]
if vector_name not in vector_config: if vector_name not in vector_config:
raise QdrantVectorStoreError( msg = (
f"Existing Qdrant collection {collection_name} does not " f"Existing Qdrant collection {collection_name} does not "
f"contain dense vector named {vector_name}. " f"contain dense vector named {vector_name}. "
"Did you mean one of the " "Did you mean one of the "
f"existing vectors: {', '.join(vector_config.keys())}? " # type: ignore f"existing vectors: {', '.join(vector_config.keys())}? " # type: ignore[union-attr]
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantVectorStoreError(msg)
# Get the VectorParams object for the specified vector_name # Get the VectorParams object for the specified vector_name
vector_config = vector_config[vector_name] # type: ignore vector_config = vector_config[vector_name] # type: ignore[assignment, index]
else: # vector_config is an instance of VectorParams
# vector_config is an instance of VectorParams # Case of a collection with single/unnamed vector.
# Case of a collection with single/unnamed vector. elif vector_name != "":
if vector_name != "": msg = (
raise QdrantVectorStoreError( f"Existing Qdrant collection {collection_name} is built "
f"Existing Qdrant collection {collection_name} is built " "with unnamed dense vector. "
"with unnamed dense vector. " f"If you want to reuse it, set `vector_name` to ''(empty string)."
f"If you want to reuse it, set `vector_name` to ''(empty string)." f"If you want to recreate the collection, "
f"If you want to recreate the collection, " "set `force_recreate` to `True`."
"set `force_recreate` to `True`." )
) raise QdrantVectorStoreError(msg)
if vector_config is None: if vector_config is None:
raise ValueError("VectorParams is None") msg = "VectorParams is None"
raise ValueError(msg)
if isinstance(dense_embeddings, Embeddings): if isinstance(dense_embeddings, Embeddings):
vector_size = len(dense_embeddings.embed_documents(["dummy_text"])[0]) vector_size = len(dense_embeddings.embed_documents(["dummy_text"])[0])
elif isinstance(dense_embeddings, list): elif isinstance(dense_embeddings, list):
vector_size = len(dense_embeddings) vector_size = len(dense_embeddings)
else: else:
raise ValueError("Invalid `embeddings` type.") msg = "Invalid `embeddings` type."
raise ValueError(msg)
if vector_config.size != vector_size: if vector_config.size != vector_size:
raise QdrantVectorStoreError( msg = (
f"Existing Qdrant collection is configured for dense vectors with " f"Existing Qdrant collection is configured for dense vectors with "
f"{vector_config.size} dimensions. " f"{vector_config.size} dimensions. "
f"Selected embeddings are {vector_size}-dimensional. " f"Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantVectorStoreError(msg)
if vector_config.distance != distance: if vector_config.distance != distance:
raise QdrantVectorStoreError( msg = (
f"Existing Qdrant collection is configured for " f"Existing Qdrant collection is configured for "
f"{vector_config.distance.name} similarity, but requested " f"{vector_config.distance.name} similarity, but requested "
f"{distance.upper()}. Please set `distance` parameter to " f"{distance.upper()}. Please set `distance` parameter to "
@ -1158,6 +1185,7 @@ class QdrantVectorStore(VectorStore):
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantVectorStoreError(msg)
@classmethod @classmethod
def _validate_collection_for_sparse( def _validate_collection_for_sparse(
@ -1173,12 +1201,13 @@ class QdrantVectorStore(VectorStore):
sparse_vector_config is None sparse_vector_config is None
or sparse_vector_name not in sparse_vector_config or sparse_vector_name not in sparse_vector_config
): ):
raise QdrantVectorStoreError( msg = (
f"Existing Qdrant collection {collection_name} does not " f"Existing Qdrant collection {collection_name} does not "
f"contain sparse vectors named {sparse_vector_name}. " f"contain sparse vectors named {sparse_vector_name}. "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantVectorStoreError(msg)
@classmethod @classmethod
def _validate_embeddings( def _validate_embeddings(
@ -1188,19 +1217,18 @@ class QdrantVectorStore(VectorStore):
sparse_embedding: Optional[SparseEmbeddings], sparse_embedding: Optional[SparseEmbeddings],
) -> None: ) -> None:
if retrieval_mode == RetrievalMode.DENSE and embedding is None: if retrieval_mode == RetrievalMode.DENSE and embedding is None:
raise ValueError( msg = "'embedding' cannot be None when retrieval mode is 'dense'"
"'embedding' cannot be None when retrieval mode is 'dense'" raise ValueError(msg)
)
elif retrieval_mode == RetrievalMode.SPARSE and sparse_embedding is None: if retrieval_mode == RetrievalMode.SPARSE and sparse_embedding is None:
raise ValueError( msg = "'sparse_embedding' cannot be None when retrieval mode is 'sparse'"
"'sparse_embedding' cannot be None when retrieval mode is 'sparse'" raise ValueError(msg)
)
elif retrieval_mode == RetrievalMode.HYBRID and any( if retrieval_mode == RetrievalMode.HYBRID and any(
[embedding is None, sparse_embedding is None] [embedding is None, sparse_embedding is None]
): ):
raise ValueError( msg = (
"Both 'embedding' and 'sparse_embedding' cannot be None " "Both 'embedding' and 'sparse_embedding' cannot be None "
"when retrieval mode is 'hybrid'" "when retrieval mode is 'hybrid'"
) )
raise ValueError(msg)

View File

@ -5,9 +5,7 @@ from pydantic import BaseModel, Field
class SparseVector(BaseModel, extra="forbid"): class SparseVector(BaseModel, extra="forbid"):
""" """Sparse vector structure."""
Sparse vector structure
"""
indices: list[int] = Field(..., description="indices must be unique") indices: list[int] = Field(..., description="indices must be unique")
values: list[float] = Field( values: list[float] = Field(

View File

@ -7,13 +7,7 @@ import warnings
from collections.abc import AsyncGenerator, Generator, Iterable, Sequence from collections.abc import AsyncGenerator, Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import TYPE_CHECKING, Any, Callable, Optional, Union
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
import numpy as np import numpy as np
from langchain_core._api.deprecation import deprecated from langchain_core._api.deprecation import deprecated
@ -37,10 +31,11 @@ class QdrantException(Exception):
def sync_call_fallback(method: Callable) -> Callable: def sync_call_fallback(method: Callable) -> Callable:
""" """Call the synchronous method if the async method is not implemented.
Decorator to call the synchronous method of the class if the async method is not
implemented. This decorator might be only used for the methods that are defined This decorator should only be used for methods that are defined as async in the
as async in the class. class.
""" """
@functools.wraps(method) @functools.wraps(method)
@ -93,27 +88,29 @@ class Qdrant(VectorStore):
): ):
"""Initialize with necessary components.""" """Initialize with necessary components."""
if not isinstance(client, QdrantClient): if not isinstance(client, QdrantClient):
raise ValueError( msg = (
f"client should be an instance of qdrant_client.QdrantClient, " f"client should be an instance of qdrant_client.QdrantClient, "
f"got {type(client)}" f"got {type(client)}"
) )
raise ValueError(msg)
if async_client is not None and not isinstance(async_client, AsyncQdrantClient): if async_client is not None and not isinstance(async_client, AsyncQdrantClient):
raise ValueError( msg = (
f"async_client should be an instance of qdrant_client.AsyncQdrantClient" f"async_client should be an instance of qdrant_client.AsyncQdrantClient"
f"got {type(async_client)}" f"got {type(async_client)}"
) )
raise ValueError(msg)
if embeddings is None and embedding_function is None: if embeddings is None and embedding_function is None:
raise ValueError( msg = "`embeddings` value can't be None. Pass `embeddings` instance."
"`embeddings` value can't be None. Pass `embeddings` instance." raise ValueError(msg)
)
if embeddings is not None and embedding_function is not None: if embeddings is not None and embedding_function is not None:
raise ValueError( msg = (
"Both `embeddings` and `embedding_function` are passed. " "Both `embeddings` and `embedding_function` are passed. "
"Use `embeddings` only." "Use `embeddings` only."
) )
raise ValueError(msg)
self._embeddings = embeddings self._embeddings = embeddings
self._embeddings_function = embedding_function self._embeddings_function = embedding_function
@ -127,13 +124,15 @@ class Qdrant(VectorStore):
if embedding_function is not None: if embedding_function is not None:
warnings.warn( warnings.warn(
"Using `embedding_function` is deprecated. " "Using `embedding_function` is deprecated. "
"Pass `Embeddings` instance to `embeddings` instead." "Pass `Embeddings` instance to `embeddings` instead.",
stacklevel=2,
) )
if not isinstance(embeddings, Embeddings): if not isinstance(embeddings, Embeddings):
warnings.warn( warnings.warn(
"`embeddings` should be an instance of `Embeddings`." "`embeddings` should be an instance of `Embeddings`."
"Using `embeddings` as `embedding_function` which is deprecated" "Using `embeddings` as `embedding_function` which is deprecated",
stacklevel=2,
) )
self._embeddings_function = embeddings self._embeddings_function = embeddings
self._embeddings = None self._embeddings = None
@ -163,9 +162,11 @@ class Qdrant(VectorStore):
batch_size: batch_size:
How many vectors upload per-request. How many vectors upload per-request.
Default: ``64`` Default: ``64``
**kwargs: Additional keyword arguments.
Returns: Returns:
List of ids from adding the texts into the vectorstore. List of ids from adding the texts into the vectorstore.
""" """
added_ids = [] added_ids = []
for batch_ids, points in self._generate_rest_batches( for batch_ids, points in self._generate_rest_batches(
@ -198,16 +199,17 @@ class Qdrant(VectorStore):
batch_size: batch_size:
How many vectors upload per-request. How many vectors upload per-request.
Default: ``64`` Default: ``64``
**kwargs: Additional keyword arguments.
Returns: Returns:
List of ids from adding the texts into the vectorstore. List of ids from adding the texts into the vectorstore.
""" """
if self.async_client is None or isinstance( if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal self.async_client._client, AsyncQdrantLocal
): ):
raise NotImplementedError( msg = "QdrantLocal cannot interoperate with sync and async clients"
"QdrantLocal cannot interoperate with sync and async clients" raise NotImplementedError(msg)
)
added_ids = [] added_ids = []
async for batch_ids, points in self._agenerate_rest_batches( async for batch_ids, points in self._agenerate_rest_batches(
@ -224,7 +226,7 @@ class Qdrant(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -264,6 +266,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = self.similarity_search_with_score( results = self.similarity_search_with_score(
query, query,
@ -282,16 +285,20 @@ class Qdrant(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
**kwargs: Any, **kwargs: Any,
) -> list[Document]: ) -> list[Document]:
"""Return docs most similar to query. """Return docs most similar to query.
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None. filter: Filter by metadata. Defaults to None.
**kwargs: Additional keyword arguments.
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = await self.asimilarity_search_with_score(query, k, filter, **kwargs) results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
return list(map(itemgetter(0), results)) return list(map(itemgetter(0), results))
@ -300,7 +307,7 @@ class Qdrant(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -340,6 +347,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
""" """
return self.similarity_search_with_score_by_vector( return self.similarity_search_with_score_by_vector(
self._embed_query(query), self._embed_query(query),
@ -357,7 +365,7 @@ class Qdrant(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -398,6 +406,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
""" """
query_embedding = await self._aembed_query(query) query_embedding = await self._aembed_query(query)
return await self.asimilarity_search_with_score_by_vector( return await self.asimilarity_search_with_score_by_vector(
@ -415,7 +424,7 @@ class Qdrant(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -455,6 +464,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = self.similarity_search_with_score_by_vector( results = self.similarity_search_with_score_by_vector(
embedding, embedding,
@ -473,7 +483,7 @@ class Qdrant(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -514,6 +524,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
results = await self.asimilarity_search_with_score_by_vector( results = await self.asimilarity_search_with_score_by_vector(
embedding, embedding,
@ -531,7 +542,7 @@ class Qdrant(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -571,6 +582,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
""" """
if filter is not None and isinstance(filter, dict): if filter is not None and isinstance(filter, dict):
warnings.warn( warnings.warn(
@ -578,6 +590,7 @@ class Qdrant(VectorStore):
"filters directly: " "filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/", "https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning, DeprecationWarning,
stacklevel=2,
) )
qdrant_filter = self._qdrant_filter_from_dict(filter) qdrant_filter = self._qdrant_filter_from_dict(filter)
else: else:
@ -618,7 +631,7 @@ class Qdrant(VectorStore):
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
offset: int = 0, offset: int = 0,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
@ -659,20 +672,20 @@ class Qdrant(VectorStore):
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
"""
"""
if self.async_client is None or isinstance( if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal self.async_client._client, AsyncQdrantLocal
): ):
raise NotImplementedError( msg = "QdrantLocal cannot interoperate with sync and async clients"
"QdrantLocal cannot interoperate with sync and async clients" raise NotImplementedError(msg)
)
if filter is not None and isinstance(filter, dict): if filter is not None and isinstance(filter, dict):
warnings.warn( warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client " "Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: " "filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/", "https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning, DeprecationWarning,
stacklevel=2,
) )
qdrant_filter = self._qdrant_filter_from_dict(filter) qdrant_filter = self._qdrant_filter_from_dict(filter)
else: else:
@ -714,7 +727,7 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
@ -755,8 +768,10 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to QdrantClient.search() Any other named arguments to pass through to QdrantClient.search()
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
query_embedding = self._embed_query(query) query_embedding = self._embed_query(query)
return self.max_marginal_relevance_search_by_vector( return self.max_marginal_relevance_search_by_vector(
@ -778,7 +793,7 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
@ -820,8 +835,10 @@ class Qdrant(VectorStore):
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
AsyncQdrantClient.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
query_embedding = await self._aembed_query(query) query_embedding = await self._aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector( return await self.amax_marginal_relevance_search_by_vector(
@ -842,7 +859,7 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
@ -882,8 +899,10 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to QdrantClient.search() Any other named arguments to pass through to QdrantClient.search()
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
results = self.max_marginal_relevance_search_with_score_by_vector( results = self.max_marginal_relevance_search_with_score_by_vector(
embedding, embedding,
@ -905,15 +924,17 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any, **kwargs: Any,
) -> list[Document]: ) -> list[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
embedding: Embedding vector to look up documents similar to. embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
@ -945,9 +966,11 @@ class Qdrant(VectorStore):
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
AsyncQdrantClient.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
""" """
results = await self.amax_marginal_relevance_search_with_score_by_vector( results = await self.amax_marginal_relevance_search_with_score_by_vector(
embedding, embedding,
@ -968,15 +991,17 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
embedding: Embedding vector to look up documents similar to. embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
@ -1007,9 +1032,11 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to QdrantClient.search() Any other named arguments to pass through to QdrantClient.search()
Returns: Returns:
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
""" """
query_vector = embedding query_vector = embedding
if self.vector_name is not None: if self.vector_name is not None:
@ -1056,15 +1083,17 @@ class Qdrant(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None, filter: Optional[MetadataFilter] = None, # noqa: A002
search_params: Optional[models.SearchParams] = None, search_params: Optional[models.SearchParams] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
consistency: Optional[models.ReadConsistency] = None, consistency: Optional[models.ReadConsistency] = None,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents. among selected documents.
Args: Args:
embedding: Embedding vector to look up documents similar to. embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
@ -1074,16 +1103,22 @@ class Qdrant(VectorStore):
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
Defaults to 0.5. Defaults to 0.5.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params.
score_threshold: Define a minimal score threshold for the result.
consistency: Read consistency of the search.
**kwargs: Additional keyword arguments.
Returns: Returns:
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
""" """
if self.async_client is None or isinstance( if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal self.async_client._client, AsyncQdrantLocal
): ):
raise NotImplementedError( msg = "QdrantLocal cannot interoperate with sync and async clients"
"QdrantLocal cannot interoperate with sync and async clients" raise NotImplementedError(msg)
)
query_vector = embedding query_vector = embedding
if self.vector_name is not None: if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment] query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
@ -1131,8 +1166,8 @@ class Qdrant(VectorStore):
Returns: Returns:
True if deletion is successful, False otherwise. True if deletion is successful, False otherwise.
"""
"""
result = self.client.delete( result = self.client.delete(
collection_name=self.collection_name, collection_name=self.collection_name,
points_selector=ids, points_selector=ids,
@ -1151,13 +1186,13 @@ class Qdrant(VectorStore):
Returns: Returns:
True if deletion is successful, False otherwise. True if deletion is successful, False otherwise.
""" """
if self.async_client is None or isinstance( if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal self.async_client._client, AsyncQdrantLocal
): ):
raise NotImplementedError( msg = "QdrantLocal cannot interoperate with sync and async clients"
"QdrantLocal cannot interoperate with sync and async clients" raise NotImplementedError(msg)
)
result = await self.async_client.delete( result = await self.async_client.delete(
collection_name=self.collection_name, collection_name=self.collection_name,
@ -1177,8 +1212,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -1193,14 +1228,14 @@ class Qdrant(VectorStore):
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None, on_disk_payload: Optional[bool] = None, # noqa: FBT001
hnsw_config: Optional[models.HnswConfigDiff] = None, hnsw_config: Optional[models.HnswConfigDiff] = None,
optimizers_config: Optional[models.OptimizersConfigDiff] = None, optimizers_config: Optional[models.OptimizersConfigDiff] = None,
wal_config: Optional[models.WalConfigDiff] = None, wal_config: Optional[models.WalConfigDiff] = None,
quantization_config: Optional[models.QuantizationConfig] = None, quantization_config: Optional[models.QuantizationConfig] = None,
init_from: Optional[models.InitFrom] = None, init_from: Optional[models.InitFrom] = None,
on_disk: Optional[bool] = None, on_disk: Optional[bool] = None, # noqa: FBT001
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts. """Construct Qdrant wrapper from a list of texts.
@ -1287,6 +1322,8 @@ class Qdrant(VectorStore):
Params for quantization, if None - quantization will be disabled Params for quantization, if None - quantization will be disabled
init_from: init_from:
Use data stored in another collection to initialize this collection Use data stored in another collection to initialize this collection
on_disk:
If true - vectors will be stored on disk, reducing memory usage.
force_recreate: force_recreate:
Force recreating the collection Force recreating the collection
**kwargs: **kwargs:
@ -1354,8 +1391,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -1366,14 +1403,14 @@ class Qdrant(VectorStore):
vector_name: Optional[str] = VECTOR_NAME, vector_name: Optional[str] = VECTOR_NAME,
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
""" """Get instance of an existing Qdrant collection.
Get instance of an existing Qdrant collection.
This method will return the instance of the store without inserting any new
embeddings
"""
This method will return the instance of the store without inserting any new
embeddings.
"""
if collection_name is None: if collection_name is None:
raise ValueError("Must specify collection_name. Received None.") msg = "Must specify collection_name. Received None."
raise ValueError(msg)
client, async_client = cls._generate_clients( client, async_client = cls._generate_clients(
location=location, location=location,
@ -1412,8 +1449,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -1428,14 +1465,14 @@ class Qdrant(VectorStore):
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None, on_disk_payload: Optional[bool] = None, # noqa: FBT001
hnsw_config: Optional[models.HnswConfigDiff] = None, hnsw_config: Optional[models.HnswConfigDiff] = None,
optimizers_config: Optional[models.OptimizersConfigDiff] = None, optimizers_config: Optional[models.OptimizersConfigDiff] = None,
wal_config: Optional[models.WalConfigDiff] = None, wal_config: Optional[models.WalConfigDiff] = None,
quantization_config: Optional[models.QuantizationConfig] = None, quantization_config: Optional[models.QuantizationConfig] = None,
init_from: Optional[models.InitFrom] = None, init_from: Optional[models.InitFrom] = None,
on_disk: Optional[bool] = None, on_disk: Optional[bool] = None, # noqa: FBT001
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts. """Construct Qdrant wrapper from a list of texts.
@ -1522,6 +1559,12 @@ class Qdrant(VectorStore):
Params for quantization, if None - quantization will be disabled Params for quantization, if None - quantization will be disabled
init_from: init_from:
Use data stored in another collection to initialize this collection Use data stored in another collection to initialize this collection
on_disk:
If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are
indexed - remain in RAM.
force_recreate: force_recreate:
Force recreating the collection Force recreating the collection
**kwargs: **kwargs:
@ -1588,8 +1631,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -1603,14 +1646,14 @@ class Qdrant(VectorStore):
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None, on_disk_payload: Optional[bool] = None, # noqa: FBT001
hnsw_config: Optional[models.HnswConfigDiff] = None, hnsw_config: Optional[models.HnswConfigDiff] = None,
optimizers_config: Optional[models.OptimizersConfigDiff] = None, optimizers_config: Optional[models.OptimizersConfigDiff] = None,
wal_config: Optional[models.WalConfigDiff] = None, wal_config: Optional[models.WalConfigDiff] = None,
quantization_config: Optional[models.QuantizationConfig] = None, quantization_config: Optional[models.QuantizationConfig] = None,
init_from: Optional[models.InitFrom] = None, init_from: Optional[models.InitFrom] = None,
on_disk: Optional[bool] = None, on_disk: Optional[bool] = None, # noqa: FBT001
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
# Just do a single quick embedding to get vector size # Just do a single quick embedding to get vector size
@ -1646,16 +1689,17 @@ class Qdrant(VectorStore):
current_vector_config = collection_info.config.params.vectors current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None: if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config: if vector_name not in current_vector_config:
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} does not " f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the " f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? " f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment] current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None: elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} uses named vectors. " f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the " f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: " f"existing named vectors: "
@ -1663,35 +1707,39 @@ class Qdrant(VectorStore):
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
elif ( elif (
not isinstance(current_vector_config, dict) and vector_name is not None not isinstance(current_vector_config, dict) and vector_name is not None
): ):
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} doesn't use named " f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to " f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set " f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`." f"`force_recreate` parameter to `True`."
) )
raise QdrantException(msg)
if not isinstance(current_vector_config, models.VectorParams): if not isinstance(current_vector_config, models.VectorParams):
raise ValueError( msg = (
"Expected current_vector_config to be an instance of " "Expected current_vector_config to be an instance of "
f"models.VectorParams, but got {type(current_vector_config)}" f"models.VectorParams, but got {type(current_vector_config)}"
) )
raise ValueError(msg)
# Check if the vector configuration has the same dimensionality. # Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: if current_vector_config.size != vector_size:
raise QdrantException( msg = (
f"Existing Qdrant collection is configured for vectors with " f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " f"{current_vector_config.size} "
f"dimensions. Selected embeddings are {vector_size}-dimensional. " f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
current_distance_func = ( current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr] current_vector_config.distance.name.upper() # type: ignore[union-attr]
) )
if current_distance_func != distance_func: if current_distance_func != distance_func:
raise QdrantException( msg = (
f"Existing Qdrant collection is configured for " f"Existing Qdrant collection is configured for "
f"{current_distance_func} similarity, but requested " f"{current_distance_func} similarity, but requested "
f"{distance_func}. Please set `distance_func` parameter to " f"{distance_func}. Please set `distance_func` parameter to "
@ -1699,6 +1747,7 @@ class Qdrant(VectorStore):
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
else: else:
vectors_config = models.VectorParams( vectors_config = models.VectorParams(
size=vector_size, size=vector_size,
@ -1727,7 +1776,7 @@ class Qdrant(VectorStore):
init_from=init_from, init_from=init_from,
timeout=timeout, # type: ignore[arg-type] timeout=timeout, # type: ignore[arg-type]
) )
qdrant = cls( return cls(
client=client, client=client,
collection_name=collection_name, collection_name=collection_name,
embeddings=embedding, embeddings=embedding,
@ -1737,7 +1786,6 @@ class Qdrant(VectorStore):
vector_name=vector_name, vector_name=vector_name,
async_client=async_client, async_client=async_client,
) )
return qdrant
@classmethod @classmethod
async def aconstruct_instance( async def aconstruct_instance(
@ -1748,8 +1796,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
@ -1763,14 +1811,14 @@ class Qdrant(VectorStore):
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None, on_disk_payload: Optional[bool] = None, # noqa: FBT001
hnsw_config: Optional[models.HnswConfigDiff] = None, hnsw_config: Optional[models.HnswConfigDiff] = None,
optimizers_config: Optional[models.OptimizersConfigDiff] = None, optimizers_config: Optional[models.OptimizersConfigDiff] = None,
wal_config: Optional[models.WalConfigDiff] = None, wal_config: Optional[models.WalConfigDiff] = None,
quantization_config: Optional[models.QuantizationConfig] = None, quantization_config: Optional[models.QuantizationConfig] = None,
init_from: Optional[models.InitFrom] = None, init_from: Optional[models.InitFrom] = None,
on_disk: Optional[bool] = None, on_disk: Optional[bool] = None, # noqa: FBT001
force_recreate: bool = False, force_recreate: bool = False, # noqa: FBT001, FBT002
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
# Just do a single quick embedding to get vector size # Just do a single quick embedding to get vector size
@ -1807,16 +1855,17 @@ class Qdrant(VectorStore):
current_vector_config = collection_info.config.params.vectors current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None: if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config: if vector_name not in current_vector_config:
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} does not " f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the " f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? " f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment] current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None: elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} uses named vectors. " f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the " f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: " f"existing named vectors: "
@ -1824,36 +1873,40 @@ class Qdrant(VectorStore):
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
elif ( elif (
not isinstance(current_vector_config, dict) and vector_name is not None not isinstance(current_vector_config, dict) and vector_name is not None
): ):
raise QdrantException( msg = (
f"Existing Qdrant collection {collection_name} doesn't use named " f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to " f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set " f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`." f"`force_recreate` parameter to `True`."
) )
raise QdrantException(msg)
if not isinstance(current_vector_config, models.VectorParams): if not isinstance(current_vector_config, models.VectorParams):
raise ValueError( msg = (
"Expected current_vector_config to be an instance of " "Expected current_vector_config to be an instance of "
f"models.VectorParams, but got {type(current_vector_config)}" f"models.VectorParams, but got {type(current_vector_config)}"
) )
raise ValueError(msg)
# Check if the vector configuration has the same dimensionality. # Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: if current_vector_config.size != vector_size:
raise QdrantException( msg = (
f"Existing Qdrant collection is configured for vectors with " f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " f"{current_vector_config.size} "
f"dimensions. Selected embeddings are {vector_size}-dimensional. " f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` " f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`." f"parameter to `True`."
) )
raise QdrantException(msg)
current_distance_func = ( current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr] current_vector_config.distance.name.upper() # type: ignore[union-attr]
) )
if current_distance_func != distance_func: if current_distance_func != distance_func:
raise QdrantException( msg = (
f"Existing Qdrant collection is configured for " f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr] f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to " f"similarity. Please set `distance_func` parameter to "
@ -1861,6 +1914,7 @@ class Qdrant(VectorStore):
f"recreate the collection, set `force_recreate` parameter to " f"recreate the collection, set `force_recreate` parameter to "
f"`True`." f"`True`."
) )
raise QdrantException(msg)
else: else:
vectors_config = models.VectorParams( vectors_config = models.VectorParams(
size=vector_size, size=vector_size,
@ -1889,7 +1943,7 @@ class Qdrant(VectorStore):
init_from=init_from, init_from=init_from,
timeout=timeout, # type: ignore[arg-type] timeout=timeout, # type: ignore[arg-type]
) )
qdrant = cls( return cls(
client=client, client=client,
collection_name=collection_name, collection_name=collection_name,
embeddings=embedding, embeddings=embedding,
@ -1899,7 +1953,6 @@ class Qdrant(VectorStore):
vector_name=vector_name, vector_name=vector_name,
async_client=async_client, async_client=async_client,
) )
return qdrant
@staticmethod @staticmethod
def _cosine_relevance_score_fn(distance: float) -> float: def _cosine_relevance_score_fn(distance: float) -> float:
@ -1907,26 +1960,24 @@ class Qdrant(VectorStore):
return (distance + 1.0) / 2.0 return (distance + 1.0) / 2.0
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
""" """Your 'correct' relevance function may differ depending on a few things.
The 'correct' relevance function
may differ depending on a few things, including: For example:
- the distance / similarity metric used by the VectorStore - The distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - The scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- embedding dimensionality - Embedding dimensionality
- etc. - etc.
""" """
if self.distance_strategy == "COSINE": if self.distance_strategy == "COSINE":
return self._cosine_relevance_score_fn return self._cosine_relevance_score_fn
elif self.distance_strategy == "DOT": if self.distance_strategy == "DOT":
return self._max_inner_product_relevance_score_fn return self._max_inner_product_relevance_score_fn
elif self.distance_strategy == "EUCLID": if self.distance_strategy == "EUCLID":
return self._euclidean_relevance_score_fn return self._euclidean_relevance_score_fn
else: msg = (
raise ValueError( "Unknown distance strategy, must be cosine, max_inner_product, or euclidean"
"Unknown distance strategy, must be cosine, " )
"max_inner_product, or euclidean" raise ValueError(msg)
)
def _similarity_search_with_relevance_scores( def _similarity_search_with_relevance_scores(
self, self,
@ -1947,6 +1998,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Tuples of (doc, similarity_score) List of Tuples of (doc, similarity_score)
""" """
return self.similarity_search_with_score(query, k, **kwargs) return self.similarity_search_with_score(query, k, **kwargs)
@ -1970,6 +2022,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Tuples of (doc, similarity_score) List of Tuples of (doc, similarity_score)
""" """
return await self.asimilarity_search_with_score(query, k, **kwargs) return await self.asimilarity_search_with_score(query, k, **kwargs)
@ -1984,10 +2037,11 @@ class Qdrant(VectorStore):
payloads = [] payloads = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
if text is None: if text is None:
raise ValueError( msg = (
"At least one of the texts is None. Please remove it before " "At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance." "calling .from_texts or .add_texts on Qdrant instance."
) )
raise ValueError(msg)
metadata = metadatas[i] if metadatas is not None else None metadata = metadatas[i] if metadatas is not None else None
payloads.append( payloads.append(
{ {
@ -2018,8 +2072,8 @@ class Qdrant(VectorStore):
out = [] out = []
if isinstance(value, dict): if isinstance(value, dict):
for _key, value in value.items(): for _key, _value in value.items():
out.extend(self._build_condition(f"{key}.{_key}", value)) out.extend(self._build_condition(f"{key}.{_key}", _value))
elif isinstance(value, list): elif isinstance(value, list):
for _value in value: for _value in value:
if isinstance(_value, dict): if isinstance(_value, dict):
@ -2037,15 +2091,15 @@ class Qdrant(VectorStore):
return out return out
def _qdrant_filter_from_dict( def _qdrant_filter_from_dict(
self, filter: Optional[DictFilter] self, filter_: Optional[DictFilter]
) -> Optional[models.Filter]: ) -> Optional[models.Filter]:
if not filter: if not filter_:
return None return None
return models.Filter( return models.Filter(
must=[ must=[
condition condition
for key, value in filter.items() for key, value in filter_.items() # type: ignore[union-attr]
for condition in self._build_condition(key, value) for condition in self._build_condition(key, value)
] ]
) )
@ -2060,14 +2114,15 @@ class Qdrant(VectorStore):
Returns: Returns:
List of floats representing the query embedding. List of floats representing the query embedding.
""" """
if self.embeddings is not None: if self.embeddings is not None:
embedding = self.embeddings.embed_query(query) embedding = self.embeddings.embed_query(query)
elif self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else: else:
if self._embeddings_function is not None: msg = "Neither of embeddings or embedding_function is set"
embedding = self._embeddings_function(query) raise ValueError(msg)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding return embedding.tolist() if hasattr(embedding, "tolist") else embedding
async def _aembed_query(self, query: str) -> list[float]: async def _aembed_query(self, query: str) -> list[float]:
@ -2080,14 +2135,15 @@ class Qdrant(VectorStore):
Returns: Returns:
List of floats representing the query embedding. List of floats representing the query embedding.
""" """
if self.embeddings is not None: if self.embeddings is not None:
embedding = await self.embeddings.aembed_query(query) embedding = await self.embeddings.aembed_query(query)
elif self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else: else:
if self._embeddings_function is not None: msg = "Neither of embeddings or embedding_function is set"
embedding = self._embeddings_function(query) raise ValueError(msg)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]: def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
@ -2100,6 +2156,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of floats representing the texts embedding. List of floats representing the texts embedding.
""" """
if self.embeddings is not None: if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts)) embeddings = self.embeddings.embed_documents(list(texts))
@ -2113,7 +2170,8 @@ class Qdrant(VectorStore):
embedding = embedding.tolist() embedding = embedding.tolist()
embeddings.append(embedding) embeddings.append(embedding)
else: else:
raise ValueError("Neither of embeddings or embedding_function is set") msg = "Neither of embeddings or embedding_function is set"
raise ValueError(msg)
return embeddings return embeddings
@ -2127,6 +2185,7 @@ class Qdrant(VectorStore):
Returns: Returns:
List of floats representing the texts embedding. List of floats representing the texts embedding.
""" """
if self.embeddings is not None: if self.embeddings is not None:
embeddings = await self.embeddings.aembed_documents(list(texts)) embeddings = await self.embeddings.aembed_documents(list(texts))
@ -2140,7 +2199,8 @@ class Qdrant(VectorStore):
embedding = embedding.tolist() embedding = embedding.tolist()
embeddings.append(embedding) embeddings.append(embedding)
else: else:
raise ValueError("Neither of embeddings or embedding_function is set") msg = "Neither of embeddings or embedding_function is set"
raise ValueError(msg)
return embeddings return embeddings
@ -2230,8 +2290,8 @@ class Qdrant(VectorStore):
url: Optional[str] = None, url: Optional[str] = None,
port: Optional[int] = 6333, port: Optional[int] = 6333,
grpc_port: int = 6334, grpc_port: int = 6334,
prefer_grpc: bool = False, prefer_grpc: bool = False, # noqa: FBT001, FBT002
https: Optional[bool] = None, https: Optional[bool] = None, # noqa: FBT001
api_key: Optional[str] = None, api_key: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,

View File

@ -51,8 +51,63 @@ langchain-core = { path = "../../core", editable = true }
target-version = "py39" target-version = "py39"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP", "S"] select = [
ignore = [ "UP007", ] "A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D203", # Messes with the formatter
"D213", # pydocstyle: Multi-line docstring summary should start at the second line (incompatible with D212)
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"UP045", # pyupgrade: non-pep604-annotation-optional
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
[tool.mypy] [tool.mypy]
disallow_untyped_defs = true disallow_untyped_defs = true

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import uuid import uuid
from typing import Optional from typing import Optional
@ -25,8 +27,8 @@ async def test_qdrant_aadd_texts_returns_all_ids(
) )
ids = await docsearch.aadd_texts(["foo", "bar", "baz"]) ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
assert 3 == len(ids) assert len(ids) == 3
assert 3 == len(set(ids)) assert len(set(ids)) == 3
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -53,8 +55,8 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
) )
ids = await vec_store.aadd_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) ids = await vec_store.aadd_texts(["abc", "abc"], [{"a": 1}, {"a": 2}])
assert 2 == len(set(ids)) assert len(set(ids)) == 2
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -84,7 +86,7 @@ async def test_qdrant_aadd_texts_stores_ids(
) )
assert all(first == second for first, second in zip(ids, returned_ids)) assert all(first == second for first, second in zip(ids, returned_ids))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
stored_ids = [point.id for point in client.scroll(collection_name)[0]] stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -116,7 +118,7 @@ async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
) )
await vec_store.aadd_texts(["lorem", "ipsum", "dolor", "sit", "amet"]) await vec_store.aadd_texts(["lorem", "ipsum", "dolor", "sit", "amet"])
assert 5 == client.count(collection_name).count assert client.count(collection_name).count == 5
assert all( assert all(
vector_name in point.vector # type: ignore[operator] vector_name in point.vector # type: ignore[operator]
for point in client.scroll(collection_name, with_vectors=True)[0] for point in client.scroll(collection_name, with_vectors=True)[0]

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import uuid import uuid
from typing import Optional from typing import Optional
@ -29,7 +31,7 @@ async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -
) )
client = vec_store.client client = vec_store.client
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -55,7 +57,7 @@ async def test_qdrant_from_texts_stores_ids(
) )
client = vec_store.client client = vec_store.client
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
stored_ids = [point.id for point in client.scroll(collection_name)[0]] stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -78,7 +80,7 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
) )
client = vec_store.client client = vec_store.client
assert 5 == client.count(collection_name).count assert client.count(collection_name).count == 5
assert all( assert all(
vector_name in point.vector # type: ignore[operator] vector_name in point.vector # type: ignore[operator]
for point in client.scroll(collection_name, with_vectors=True)[0] for point in client.scroll(collection_name, with_vectors=True)[0]
@ -90,7 +92,7 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
async def test_qdrant_from_texts_reuses_same_collection( async def test_qdrant_from_texts_reuses_same_collection(
location: str, vector_name: Optional[str] location: str, vector_name: Optional[str]
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts reuses the same collection""" """Test if Qdrant.afrom_texts reuses the same collection."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings() embeddings = ConsistentFakeEmbeddings()
@ -111,7 +113,7 @@ async def test_qdrant_from_texts_reuses_same_collection(
) )
client = vec_store.client client = vec_store.client
assert 7 == client.count(collection_name).count assert client.count(collection_name).count == 7
@pytest.mark.parametrize("location", qdrant_locations(use_in_memory=False)) @pytest.mark.parametrize("location", qdrant_locations(use_in_memory=False))
@ -121,7 +123,8 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts raises an exception if dimensionality does not """Test if Qdrant.afrom_texts raises an exception if dimensionality does not
match""" match.
""" # noqa: D205
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts( await Qdrant.afrom_texts(
@ -156,7 +159,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
first_vector_name: Optional[str], first_vector_name: Optional[str],
second_vector_name: Optional[str], second_vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts raises an exception if vector name does not match""" """Test if Qdrant.afrom_texts raises an exception if vector name does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts( await Qdrant.afrom_texts(
@ -181,7 +184,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
async def test_qdrant_from_texts_raises_error_on_different_distance( async def test_qdrant_from_texts_raises_error_on_different_distance(
location: str, location: str,
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts raises an exception if distance does not match""" """Test if Qdrant.afrom_texts raises an exception if distance does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts( await Qdrant.afrom_texts(
@ -208,7 +211,7 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
location: str, location: str,
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts recreates the collection even if config mismatches""" """Test if Qdrant.afrom_texts recreates the collection even if config mismatches."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
@ -231,11 +234,11 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
) )
client = QdrantClient(location=location, api_key=os.getenv("QDRANT_API_KEY")) client = QdrantClient(location=location, api_key=os.getenv("QDRANT_API_KEY"))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
vector_params = client.get_collection(collection_name).config.params.vectors vector_params = client.get_collection(collection_name).config.params.vectors
if vector_name is not None: if vector_name is not None:
vector_params = vector_params[vector_name] # type: ignore[index] vector_params = vector_params[vector_name] # type: ignore[index]
assert 5 == vector_params.size # type: ignore[union-attr] assert vector_params.size == 5 # type: ignore[union-attr]
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Optional from typing import Optional
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@ -187,7 +189,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
"foo", k=3, **kwargs "foo", k=3, **kwargs
) )
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -222,7 +224,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
"foo", k=3, **kwargs "foo", k=3, **kwargs
) )
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -301,5 +303,5 @@ async def test_qdrant_similarity_search_with_relevance_scores(
output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
assert all( assert all(
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output (score <= 1 or np.isclose(score, 1)) and score >= 0 for _, score in output
) )

View File

@ -1,4 +1,4 @@
import requests # type: ignore import requests # type: ignore[import-untyped]
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -7,7 +7,6 @@ from langchain_qdrant import SparseEmbeddings, SparseVector
def qdrant_running_locally() -> bool: def qdrant_running_locally() -> bool:
"""Check if Qdrant is running at http://localhost:6333.""" """Check if Qdrant is running at http://localhost:6333."""
try: try:
response = requests.get("http://localhost:6333", timeout=10.0) response = requests.get("http://localhost:6333", timeout=10.0)
response_json = response.json() response_json = response.json()
@ -33,7 +32,8 @@ def assert_documents_equals(actual: list[Document], expected: list[Document]):
class ConsistentFakeEmbeddings(Embeddings): class ConsistentFakeEmbeddings(Embeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent """Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts.""" vectors for the same texts.
""" # noqa: D205
def __init__(self, dimensionality: int = 10) -> None: def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: list[str] = [] self.known_texts: list[str] = []
@ -53,13 +53,15 @@ class ConsistentFakeEmbeddings(Embeddings):
def embed_query(self, text: str) -> list[float]: def embed_query(self, text: str) -> list[float]:
"""Return consistent embeddings for the text, if seen before, or a constant """Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown.""" one if the text is unknown.
""" # noqa: D205
return self.embed_documents([text])[0] return self.embed_documents([text])[0]
class ConsistentFakeSparseEmbeddings(SparseEmbeddings): class ConsistentFakeSparseEmbeddings(SparseEmbeddings):
"""Fake sparse embeddings which remembers all the texts seen so far " """Fake sparse embeddings which remembers all the texts seen so far
"to return consistent vectors for the same texts.""" "to return consistent vectors for the same texts.
""" # noqa: D205
def __init__(self, dimensionality: int = 25) -> None: def __init__(self, dimensionality: int = 25) -> None:
self.known_texts: list[str] = [] self.known_texts: list[str] = []
@ -78,6 +80,7 @@ class ConsistentFakeSparseEmbeddings(SparseEmbeddings):
return out_vectors return out_vectors
def embed_query(self, text: str) -> SparseVector: def embed_query(self, text: str) -> SparseVector:
"""Return consistent embeddings for the text, " """Return consistent embeddings for the text, if seen before, or a constant
"if seen before, or a constant one if the text is unknown.""" one if the text is unknown.
""" # noqa: D205
return self.embed_documents([text])[0] return self.embed_documents([text])[0]

View File

@ -7,7 +7,7 @@ from tests.integration_tests.common import qdrant_running_locally
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def qdrant_locations(use_in_memory: bool = True) -> list[str]: def qdrant_locations(use_in_memory: bool = True) -> list[str]: # noqa: FBT001, FBT002
locations = [] locations = []
if use_in_memory: if use_in_memory:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import uuid import uuid
from typing import Union from typing import Union
@ -71,9 +73,9 @@ def test_qdrant_add_texts_returns_all_ids(
) )
ids = docsearch.add_texts(["foo", "bar", "baz"]) ids = docsearch.add_texts(["foo", "bar", "baz"])
assert 3 == len(ids) assert len(ids) == 3
assert 3 == len(set(ids)) assert len(set(ids)) == 3
assert 3 == len(docsearch.get_by_ids(ids)) assert len(docsearch.get_by_ids(ids)) == 3
@pytest.mark.parametrize("location", qdrant_locations()) @pytest.mark.parametrize("location", qdrant_locations())
@ -83,7 +85,6 @@ def test_qdrant_add_texts_stores_duplicated_texts(
vector_name: str, vector_name: str,
) -> None: ) -> None:
"""Test end to end Qdrant.add_texts stores duplicated texts separately.""" """Test end to end Qdrant.add_texts stores duplicated texts separately."""
client = QdrantClient(location) client = QdrantClient(location)
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
vectors_config = { vectors_config = {
@ -99,8 +100,8 @@ def test_qdrant_add_texts_stores_duplicated_texts(
) )
ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}])
assert 2 == len(set(ids)) assert len(set(ids)) == 2
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("location", qdrant_locations()) @pytest.mark.parametrize("location", qdrant_locations())
@ -137,7 +138,7 @@ def test_qdrant_add_texts_stores_ids(
batch_size=batch_size, batch_size=batch_size,
) )
assert 3 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 3
stored_ids = [point.id for point in vec_store.client.scroll(collection_name)[0]] stored_ids = [point.id for point in vec_store.client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
assert 3 == len(vec_store.get_by_ids(ids)) assert len(vec_store.get_by_ids(ids)) == 3

View File

@ -23,7 +23,6 @@ def test_qdrant_from_existing_collection_uses_same_collection(
sparse_vector_name: str, sparse_vector_name: str,
) -> None: ) -> None:
"""Test if the QdrantVectorStore.from_existing_collection reuses the collection.""" """Test if the QdrantVectorStore.from_existing_collection reuses the collection."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
docs = ["foo"] docs = ["foo"]
QdrantVectorStore.from_texts( QdrantVectorStore.from_texts(
@ -48,4 +47,4 @@ def test_qdrant_from_existing_collection_uses_same_collection(
) )
qdrant.add_texts(["baz", "bar"]) qdrant.add_texts(["baz", "bar"])
assert 3 == qdrant.client.count(collection_name).count assert qdrant.client.count(collection_name).count == 3

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import uuid import uuid
from typing import Union from typing import Union
@ -30,7 +32,7 @@ def test_vectorstore_from_texts(location: str, retrieval_mode: RetrievalMode) ->
sparse_embedding=ConsistentFakeSparseEmbeddings(), sparse_embedding=ConsistentFakeSparseEmbeddings(),
) )
assert 2 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -66,7 +68,7 @@ def test_qdrant_from_texts_stores_ids(
sparse_vector_name=sparse_vector_name, sparse_vector_name=sparse_vector_name,
) )
assert 2 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 2
stored_ids = [point.id for point in vec_store.client.retrieve(collection_name, ids)] stored_ids = [point.id for point in vec_store.client.retrieve(collection_name, ids)]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -84,7 +86,6 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
sparse_vector_name: str, sparse_vector_name: str,
) -> None: ) -> None:
"""Test end to end Qdrant.from_texts stores named vectors if name is provided.""" """Test end to end Qdrant.from_texts stores named vectors if name is provided."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
vec_store = QdrantVectorStore.from_texts( vec_store = QdrantVectorStore.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"], ["lorem", "ipsum", "dolor", "sit", "amet"],
@ -97,15 +98,15 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
sparse_embedding=ConsistentFakeSparseEmbeddings(), sparse_embedding=ConsistentFakeSparseEmbeddings(),
) )
assert 5 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 5
if retrieval_mode in retrieval_modes(sparse=False): if retrieval_mode in retrieval_modes(sparse=False):
assert all( assert all(
(vector_name in point.vector or isinstance(point.vector, list)) # type: ignore (vector_name in point.vector or isinstance(point.vector, list)) # type: ignore[operator]
for point in vec_store.client.scroll(collection_name, with_vectors=True)[0] for point in vec_store.client.scroll(collection_name, with_vectors=True)[0]
) )
if retrieval_mode in retrieval_modes(dense=False): if retrieval_mode in retrieval_modes(dense=False):
assert all( assert all(
sparse_vector_name in point.vector # type: ignore sparse_vector_name in point.vector # type: ignore[operator]
for point in vec_store.client.scroll(collection_name, with_vectors=True)[0] for point in vec_store.client.scroll(collection_name, with_vectors=True)[0]
) )
@ -122,7 +123,7 @@ def test_qdrant_from_texts_reuses_same_collection(
vector_name: str, vector_name: str,
sparse_vector_name: str, sparse_vector_name: str,
) -> None: ) -> None:
"""Test if Qdrant.from_texts reuses the same collection""" """Test if Qdrant.from_texts reuses the same collection."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings() embeddings = ConsistentFakeEmbeddings()
sparse_embeddings = ConsistentFakeSparseEmbeddings() sparse_embeddings = ConsistentFakeSparseEmbeddings()
@ -149,7 +150,7 @@ def test_qdrant_from_texts_reuses_same_collection(
sparse_embedding=sparse_embeddings, sparse_embedding=sparse_embeddings,
) )
assert 7 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 7
@pytest.mark.parametrize("location", qdrant_locations(use_in_memory=False)) @pytest.mark.parametrize("location", qdrant_locations(use_in_memory=False))
@ -160,7 +161,7 @@ def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: str, vector_name: str,
retrieval_mode: RetrievalMode, retrieval_mode: RetrievalMode,
) -> None: ) -> None:
"""Test if Qdrant.from_texts raises an exception if dimensionality does not match""" """Test if Qdrant.from_texts raises an exception if dimensionality doesn't match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
QdrantVectorStore.from_texts( QdrantVectorStore.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"], ["lorem", "ipsum", "dolor", "sit", "amet"],
@ -204,7 +205,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
second_vector_name: str, second_vector_name: str,
retrieval_mode: RetrievalMode, retrieval_mode: RetrievalMode,
) -> None: ) -> None:
"""Test if Qdrant.from_texts raises an exception if vector name does not match""" """Test if Qdrant.from_texts raises an exception if vector name does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
QdrantVectorStore.from_texts( QdrantVectorStore.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"], ["lorem", "ipsum", "dolor", "sit", "amet"],
@ -237,7 +238,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
def test_qdrant_from_texts_raises_error_on_different_distance( def test_qdrant_from_texts_raises_error_on_different_distance(
location: str, vector_name: str, retrieval_mode: RetrievalMode location: str, vector_name: str, retrieval_mode: RetrievalMode
) -> None: ) -> None:
"""Test if Qdrant.from_texts raises an exception if distance does not match""" """Test if Qdrant.from_texts raises an exception if distance does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
QdrantVectorStore.from_texts( QdrantVectorStore.from_texts(
["lorem", "ipsum", "dolor", "sit", "amet"], ["lorem", "ipsum", "dolor", "sit", "amet"],
@ -302,7 +303,7 @@ def test_qdrant_from_texts_recreates_collection_on_force_recreate(
force_recreate=True, force_recreate=True,
) )
assert 2 == vec_store.client.count(collection_name).count assert vec_store.client.count(collection_name).count == 2
@pytest.mark.parametrize("location", qdrant_locations()) @pytest.mark.parametrize("location", qdrant_locations())
@ -380,6 +381,6 @@ def test_from_texts_passed_optimizers_config_and_on_disk_payload(
) )
collection_info = vec_store.client.get_collection(collection_name) collection_info = vec_store.client.get_collection(collection_name)
assert collection_info.config.params.vectors[vector_name].on_disk is True # type: ignore assert collection_info.config.params.vectors[vector_name].on_disk is True # type: ignore[index]
assert collection_info.config.optimizer_config.memmap_threshold == 1000 assert collection_info.config.optimizer_config.memmap_threshold == 1000
assert collection_info.config.params.on_disk_payload is True assert collection_info.config.params.on_disk_payload is True

View File

@ -31,7 +31,7 @@ def test_qdrant_mmr_search(
vector_name: str, vector_name: str,
) -> None: ) -> None:
"""Test end to end construction and MRR search.""" """Test end to end construction and MRR search."""
filter = models.Filter( filter_ = models.Filter(
must=[ must=[
models.FieldCondition( models.FieldCondition(
key=f"{metadata_payload_key}.page", key=f"{metadata_payload_key}.page",
@ -68,7 +68,7 @@ def test_qdrant_mmr_search(
) )
output = docsearch.max_marginal_relevance_search( output = docsearch.max_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter "foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter_
) )
assert_documents_equals( assert_documents_equals(
output, output,

View File

@ -197,7 +197,7 @@ def test_relevance_search_with_threshold(
kwargs = {"score_threshold": score_threshold} kwargs = {"score_threshold": score_threshold}
output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs)
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("location", qdrant_locations()) @pytest.mark.parametrize("location", qdrant_locations())
@ -248,7 +248,7 @@ def test_relevance_search_with_threshold_and_filter(
kwargs = {"filter": positive_filter, "score_threshold": score_threshold} kwargs = {"filter": positive_filter, "score_threshold": score_threshold}
output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs)
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("location", qdrant_locations()) @pytest.mark.parametrize("location", qdrant_locations())

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import uuid import uuid
from typing import Optional from typing import Optional
@ -48,8 +50,8 @@ def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
) )
ids = docsearch.add_texts(["foo", "bar", "baz"]) ids = docsearch.add_texts(["foo", "bar", "baz"])
assert 3 == len(ids) assert len(ids) == 3
assert 3 == len(set(ids)) assert len(set(ids)) == 3
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -73,8 +75,8 @@ def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) ->
) )
ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}]) ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}])
assert 2 == len(set(ids)) assert len(set(ids)) == 2
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -99,7 +101,7 @@ def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size) returned_ids = vec_store.add_texts(["abc", "def"], ids=ids, batch_size=batch_size)
assert all(first == second for first, second in zip(ids, returned_ids)) assert all(first == second for first, second in zip(ids, returned_ids))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
stored_ids = [point.id for point in client.scroll(collection_name)[0]] stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -128,7 +130,7 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -
) )
vec_store.add_texts(["lorem", "ipsum", "dolor", "sit", "amet"]) vec_store.add_texts(["lorem", "ipsum", "dolor", "sit", "amet"])
assert 5 == client.count(collection_name).count assert client.count(collection_name).count == 5
assert all( assert all(
vector_name in point.vector # type: ignore[operator] vector_name in point.vector # type: ignore[operator]
for point in client.scroll(collection_name, with_vectors=True)[0] for point in client.scroll(collection_name, with_vectors=True)[0]

View File

@ -4,4 +4,3 @@ import pytest # type: ignore[import-not-found]
@pytest.mark.compile @pytest.mark.compile
def test_placeholder() -> None: def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests.""" """Used for compiling integration tests without running any real tests."""
pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import uuid import uuid
from typing import Callable, Optional from typing import Callable, Optional

View File

@ -34,4 +34,4 @@ def test_qdrant_from_existing_collection_uses_same_collection(vector_name: str)
del qdrant del qdrant
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 3 == client.count(collection_name).count assert client.count(collection_name).count == 3

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import tempfile import tempfile
import uuid import uuid
from typing import Optional from typing import Optional
@ -30,7 +32,7 @@ def test_qdrant_from_texts_stores_duplicated_texts() -> None:
del vec_store del vec_store
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -59,7 +61,7 @@ def test_qdrant_from_texts_stores_ids(
del vec_store del vec_store
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
stored_ids = [point.id for point in client.scroll(collection_name)[0]] stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -81,7 +83,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str)
del vec_store del vec_store
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 5 == client.count(collection_name).count assert client.count(collection_name).count == 5
assert all( assert all(
vector_name in point.vector # type: ignore[operator] vector_name in point.vector # type: ignore[operator]
for point in client.scroll(collection_name, with_vectors=True)[0] for point in client.scroll(collection_name, with_vectors=True)[0]
@ -90,7 +92,7 @@ def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str)
@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None: def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None:
"""Test if Qdrant.from_texts reuses the same collection""" """Test if Qdrant.from_texts reuses the same collection."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
@ -115,14 +117,14 @@ def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) ->
del vec_store del vec_store
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 7 == client.count(collection_name).count assert client.count(collection_name).count == 7
@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_raises_error_on_different_dimensionality( def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.from_texts raises an exception if dimensionality does not match""" """Test if Qdrant.from_texts raises an exception if dimensionality doesn't match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts( vec_store = Qdrant.from_texts(
@ -156,7 +158,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
first_vector_name: Optional[str], first_vector_name: Optional[str],
second_vector_name: Optional[str], second_vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.from_texts raises an exception if vector name does not match""" """Test if Qdrant.from_texts raises an exception if vector name does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts( vec_store = Qdrant.from_texts(
@ -179,7 +181,7 @@ def test_qdrant_from_texts_raises_error_on_different_vector_name(
def test_qdrant_from_texts_raises_error_on_different_distance() -> None: def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
"""Test if Qdrant.from_texts raises an exception if distance does not match""" """Test if Qdrant.from_texts raises an exception if distance does not match."""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
vec_store = Qdrant.from_texts( vec_store = Qdrant.from_texts(
@ -211,7 +213,7 @@ def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
def test_qdrant_from_texts_recreates_collection_on_force_recreate( def test_qdrant_from_texts_recreates_collection_on_force_recreate(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.from_texts recreates the collection even if config mismatches""" """Test if Qdrant.from_texts recreates the collection even if config mismatches."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
@ -236,7 +238,7 @@ def test_qdrant_from_texts_recreates_collection_on_force_recreate(
del vec_store del vec_store
client = QdrantClient(path=str(tmpdir)) client = QdrantClient(path=str(tmpdir))
assert 2 == client.count(collection_name).count assert client.count(collection_name).count == 2
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@ -283,6 +285,6 @@ def test_from_texts_passed_optimizers_config_and_on_disk_payload(location: str)
) )
collection_info = vec_store.client.get_collection(collection_name) collection_info = vec_store.client.get_collection(collection_name)
assert collection_info.config.params.vectors.on_disk is True # type: ignore assert collection_info.config.params.vectors.on_disk is True # type: ignore[union-attr]
assert collection_info.config.optimizer_config.memmap_threshold == 1000 assert collection_info.config.optimizer_config.memmap_threshold == 1000
assert collection_info.config.params.on_disk_payload is True assert collection_info.config.params.on_disk_payload is True

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Optional from typing import Optional
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
@ -23,7 +25,7 @@ def test_qdrant_max_marginal_relevance_search(
"""Test end to end construction and MRR search.""" """Test end to end construction and MRR search."""
from qdrant_client import models from qdrant_client import models
filter = models.Filter( filter_ = models.Filter(
must=[ must=[
models.FieldCondition( models.FieldCondition(
key=f"{metadata_payload_key}.page", key=f"{metadata_payload_key}.page",
@ -59,7 +61,7 @@ def test_qdrant_max_marginal_relevance_search(
) )
output = docsearch.max_marginal_relevance_search( output = docsearch.max_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter "foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter_
) )
assert_documents_equals( assert_documents_equals(
output, output,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@ -174,7 +176,7 @@ def test_qdrant_similarity_search_with_relevance_score_with_threshold(
kwargs = {"score_threshold": score_threshold} kwargs = {"score_threshold": score_threshold}
output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs)
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -205,7 +207,7 @@ def test_qdrant_similarity_search_with_relevance_score_with_threshold_and_filter
kwargs = {"filter": positive_filter, "score_threshold": score_threshold} kwargs = {"filter": positive_filter, "score_threshold": score_threshold}
output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs) output = docsearch.similarity_search_with_relevance_scores("foo", k=3, **kwargs)
assert len(output) == 1 assert len(output) == 1
assert all([score >= score_threshold for _, score in output]) assert all(score >= score_threshold for _, score in output)
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@ -280,5 +282,5 @@ def test_qdrant_similarity_search_with_relevance_scores(
output = docsearch.similarity_search_with_relevance_scores("foo", k=3) output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
assert all( assert all(
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output (score <= 1 or np.isclose(score, 1)) and score >= 0 for _, score in output
) )

File diff suppressed because it is too large Load Diff