From 2bcf581a2302693e5df9d185ab64e62bf0bd9bc5 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 24 Aug 2023 17:11:30 -0400 Subject: [PATCH] Added search parameters to qdrant max_marginal_relevance_search (#7745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the qdrant search filter/params to the `max_marginal_relevance_search` method, which are already present on the other search methods. I did not add `offset` for pagination, because its behavior would be ambiguous in this setting (since we fetch extra and down-select). --------- Co-authored-by: Bagatur Co-authored-by: Kacper Łukawski --- .../langchain/vectorstores/qdrant.py | 306 +++++++++++++++--- .../qdrant/test_max_marginal_relevance.py | 19 ++ 2 files changed, 277 insertions(+), 48 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py index 7b9d9869ce9..cdc5bea8efb 100644 --- a/libs/langchain/langchain/vectorstores/qdrant.py +++ b/libs/langchain/langchain/vectorstores/qdrant.py @@ -265,6 +265,8 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of Documents most similar to the query. @@ -339,6 +341,8 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of documents most similar to the query text and distance for each. 
@@ -394,6 +398,9 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to + QdrantClient.async_grpc_points.Search(). Returns: List of documents most similar to the query text and distance for each. @@ -448,6 +455,8 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of Documents most similar to the query. @@ -504,6 +513,9 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to + QdrantClient.async_grpc_points.Search(). Returns: List of Documents most similar to the query. @@ -559,6 +571,8 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of documents most similar to the query text and distance for each. 
@@ -601,6 +615,56 @@ class Qdrant(VectorStore): for result in results ] + async def _asearch_with_score_by_vector( + self, + embedding: List[float], + *, + k: int = 4, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + offset: int = 0, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, + with_vectors: bool = False, + **kwargs: Any, + ) -> Any: + """Return results most similar to embedding vector.""" + from qdrant_client import grpc # noqa + from qdrant_client.conversions.conversion import RestToGrpc + from qdrant_client.http import models as rest + + if filter is not None and isinstance(filter, dict): + warnings.warn( + "Using dict as a `filter` is deprecated. Please use qdrant-client " + "filters directly: " + "https://qdrant.tech/documentation/concepts/filtering/", + DeprecationWarning, + ) + qdrant_filter = self._qdrant_filter_from_dict(filter) + else: + qdrant_filter = filter + + if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter): + qdrant_filter = RestToGrpc.convert_filter(qdrant_filter) + + response = await self.client.async_grpc_points.Search( + grpc.SearchPoints( + collection_name=self.collection_name, + vector_name=self.vector_name, + vector=embedding, + filter=qdrant_filter, + params=search_params, + limit=k, + offset=offset, + with_payload=grpc.WithPayloadSelector(enable=True), + with_vectors=grpc.WithVectorsSelector(enable=with_vectors), + score_threshold=score_threshold, + read_consistency=consistency, + **kwargs, + ) + ) + return response + @sync_call_fallback async def asimilarity_search_with_score_by_vector( self, @@ -641,43 +705,22 @@ class Qdrant(VectorStore): - 'quorum' - query the majority of replicas, return values present in all of them - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to + QdrantClient.async_grpc_points.Search(). 
Returns: List of documents most similar to the query text and distance for each. """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - from qdrant_client.http import models as rest - - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter): - qdrant_filter = RestToGrpc.convert_filter(qdrant_filter) - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - filter=qdrant_filter, - params=search_params, - limit=k, - offset=offset, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=False), - score_threshold=score_threshold, - read_consistency=consistency, - **kwargs, - ) + response = await self._asearch_with_score_by_vector( + embedding, + k=k, + filter=filter, + search_params=search_params, + offset=offset, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) return [ @@ -696,6 +739,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -712,12 +759,41 @@ class Qdrant(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + filter: Filter by metadata. 
Defaults to None. + search_params: Additional search params + score_threshold: + Define a minimal score threshold for the result. + If defined, less similar results will not be returned. + Score of the returned result might be higher or smaller than the + threshold depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + consistency: + Read consistency of the search. Defines how many replicas should be + queried before returning the result. + Values: + - int - number of replicas to query, values should present in all + queried replicas + - 'majority' - query all replicas, but return values present in the + majority of replicas + - 'quorum' - query the majority of replicas, return values present in + all of them + - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of Documents selected by maximal marginal relevance. """ query_embedding = self._embed_query(query) return self.max_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs + query_embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + search_params=search_params, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) @sync_call_fallback @@ -727,6 +803,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -743,12 +823,42 @@ class Qdrant(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. + filter: Filter by metadata. Defaults to None. 
+ search_params: Additional search params + score_threshold: + Define a minimal score threshold for the result. + If defined, less similar results will not be returned. + Score of the returned result might be higher or smaller than the + threshold depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + consistency: + Read consistency of the search. Defines how many replicas should be + queried before returning the result. + Values: + - int - number of replicas to query, values should present in all + queried replicas + - 'majority' - query all replicas, but return values present in the + majority of replicas + - 'quorum' - query the majority of replicas, return values present in + all of them + - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to + QdrantClient.async_grpc_points.Search(). Returns: List of Documents selected by maximal marginal relevance. """ query_embedding = self._embed_query(query) return await self.amax_marginal_relevance_search_by_vector( - query_embedding, k, fetch_k, lambda_mult, **kwargs + query_embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + search_params=search_params, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) def max_marginal_relevance_search_by_vector( @@ -757,6 +867,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -772,11 +886,40 @@ class Qdrant(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. 
+ filter: Filter by metadata. Defaults to None. + search_params: Additional search params + score_threshold: + Define a minimal score threshold for the result. + If defined, less similar results will not be returned. + Score of the returned result might be higher or smaller than the + threshold depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + consistency: + Read consistency of the search. Defines how many replicas should be + queried before returning the result. + Values: + - int - number of replicas to query, values should present in all + queried replicas + - 'majority' - query all replicas, but return values present in the + majority of replicas + - 'quorum' - query the majority of replicas, return values present in + all of them + - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of Documents selected by maximal marginal relevance. """ results = self.max_marginal_relevance_search_with_score_by_vector( - embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + search_params=search_params, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) return list(map(itemgetter(0), results)) @@ -787,6 +930,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -801,12 +948,42 @@ class Qdrant(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. 
+ filter: Filter by metadata. Defaults to None. + search_params: Additional search params + score_threshold: + Define a minimal score threshold for the result. + If defined, less similar results will not be returned. + Score of the returned result might be higher or smaller than the + threshold depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + consistency: + Read consistency of the search. Defines how many replicas should be + queried before returning the result. + Values: + - int - number of replicas to query, values should present in all + queried replicas + - 'majority' - query all replicas, but return values present in the + majority of replicas + - 'quorum' - query the majority of replicas, return values present in + all of them + - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to + QdrantClient.async_grpc_points.Search(). Returns: List of Documents selected by maximal marginal relevance and distance for each. """ results = await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, k, fetch_k, lambda_mult, **kwargs + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + search_params=search_params, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) return list(map(itemgetter(0), results)) @@ -816,6 +993,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance. @@ -830,6 +1011,27 @@ class Qdrant(VectorStore): of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. 
Defaults to 0.5. + filter: Filter by metadata. Defaults to None. + search_params: Additional search params + score_threshold: + Define a minimal score threshold for the result. + If defined, less similar results will not be returned. + Score of the returned result might be higher or smaller than the + threshold depending on the Distance function used. + E.g. for cosine similarity only higher scores will be returned. + consistency: + Read consistency of the search. Defines how many replicas should be + queried before returning the result. + Values: + - int - number of replicas to query, values should present in all + queried replicas + - 'majority' - query all replicas, but return values present in the + majority of replicas + - 'quorum' - query the majority of replicas, return values present in + all of them + - 'all' - query all replicas, and return values present in all replicas + **kwargs: + Any other named arguments to pass through to QdrantClient.search() Returns: List of Documents selected by maximal marginal relevance and distance for each. @@ -841,9 +1043,14 @@ class Qdrant(VectorStore): results = self.client.search( collection_name=self.collection_name, query_vector=query_vector, + query_filter=filter, + search_params=search_params, + limit=fetch_k, with_payload=True, with_vectors=True, - limit=fetch_k, + score_threshold=score_threshold, + consistency=consistency, + **kwargs, ) embeddings = [ result.vector.get(self.vector_name) # type: ignore[index, union-attr] @@ -871,6 +1078,10 @@ class Qdrant(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[MetadataFilter] = None, + search_params: Optional[common_types.SearchParams] = None, + score_threshold: Optional[float] = None, + consistency: Optional[common_types.ReadConsistency] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance. 
@@ -889,18 +1100,17 @@ class Qdrant(VectorStore): List of Documents selected by maximal marginal relevance and distance for each. """ - from qdrant_client import grpc # noqa from qdrant_client.conversions.conversion import GrpcToRest - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=True), - limit=fetch_k, - ) + response = await self._asearch_with_score_by_vector( + embedding, + k=fetch_k, + filter=filter, + search_params=search_params, + score_threshold=score_threshold, + consistency=consistency, + with_vectors=True, + **kwargs, ) results = [ GrpcToRest.convert_vectors(result.vectors) for result in response.result diff --git a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py index 5a383b36cc4..71d1643b789 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py +++ b/libs/langchain/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +from qdrant_client import models from langchain.schema import Document from langchain.vectorstores import Qdrant @@ -20,6 +21,17 @@ def test_qdrant_max_marginal_relevance_search( vector_name: Optional[str], ) -> None: """Test end to end construction and MRR search.""" + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"{metadata_payload_key}.page", + match=models.MatchValue( + value=2, + ), + ), + ], + ) + texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = Qdrant.from_texts( @@ -40,3 +52,10 @@ def test_qdrant_max_marginal_relevance_search( Document(page_content="foo", metadata={"page": 0}), 
Document(page_content="baz", metadata={"page": 2}), ] + + output = docsearch.max_marginal_relevance_search( + "foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter + ) + assert output == [ + Document(page_content="baz", metadata={"page": 2}), + ]