mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 13:18:12 +00:00
Added search parameters to qdrant max_marginal_relevance_search (#7745)
Adds the qdrant search filter/params to the `max_marginal_relevance_search` method, which is present on others. I did not add `offset` for pagination, because it's behavior would be ambiguous in this setting (since we fetch extra and down-select). --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Kacper Łukawski <lukawski.kacper@gmail.com>
This commit is contained in:
parent
22b6549a34
commit
2bcf581a23
@ -265,6 +265,8 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
@ -339,6 +341,8 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
@ -394,6 +398,9 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
@ -448,6 +455,8 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
@ -504,6 +513,9 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
@ -559,6 +571,8 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
@ -601,6 +615,56 @@ class Qdrant(VectorStore):
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def _asearch_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
offset: int = 0,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
with_vectors: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Return results most similar to embedding vector."""
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions.conversion import RestToGrpc
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
if filter is not None and isinstance(filter, dict):
|
||||
warnings.warn(
|
||||
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
||||
"filters directly: "
|
||||
"https://qdrant.tech/documentation/concepts/filtering/",
|
||||
DeprecationWarning,
|
||||
)
|
||||
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
||||
else:
|
||||
qdrant_filter = filter
|
||||
|
||||
if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
|
||||
qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
|
||||
|
||||
response = await self.client.async_grpc_points.Search(
|
||||
grpc.SearchPoints(
|
||||
collection_name=self.collection_name,
|
||||
vector_name=self.vector_name,
|
||||
vector=embedding,
|
||||
filter=qdrant_filter,
|
||||
params=search_params,
|
||||
limit=k,
|
||||
offset=offset,
|
||||
with_payload=grpc.WithPayloadSelector(enable=True),
|
||||
with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
|
||||
score_threshold=score_threshold,
|
||||
read_consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_with_score_by_vector(
|
||||
self,
|
||||
@ -641,44 +705,23 @@ class Qdrant(VectorStore):
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions.conversion import RestToGrpc
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
if filter is not None and isinstance(filter, dict):
|
||||
warnings.warn(
|
||||
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
||||
"filters directly: "
|
||||
"https://qdrant.tech/documentation/concepts/filtering/",
|
||||
DeprecationWarning,
|
||||
)
|
||||
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
||||
else:
|
||||
qdrant_filter = filter
|
||||
|
||||
if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
|
||||
qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
|
||||
|
||||
response = await self.client.async_grpc_points.Search(
|
||||
grpc.SearchPoints(
|
||||
collection_name=self.collection_name,
|
||||
vector_name=self.vector_name,
|
||||
vector=embedding,
|
||||
filter=qdrant_filter,
|
||||
params=search_params,
|
||||
limit=k,
|
||||
response = await self._asearch_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
offset=offset,
|
||||
with_payload=grpc.WithPayloadSelector(enable=True),
|
||||
with_vectors=grpc.WithVectorsSelector(enable=False),
|
||||
score_threshold=score_threshold,
|
||||
read_consistency=consistency,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
(
|
||||
@ -696,6 +739,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -712,12 +759,41 @@ class Qdrant(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
score_threshold:
|
||||
Define a minimal score threshold for the result.
|
||||
If defined, less similar results will not be returned.
|
||||
Score of the returned result might be higher or smaller than the
|
||||
threshold depending on the Distance function used.
|
||||
E.g. for cosine similarity only higher scores will be returned.
|
||||
consistency:
|
||||
Read consistency of the search. Defines how many replicas should be
|
||||
queried before returning the result.
|
||||
Values:
|
||||
- int - number of replicas to query, values should present in all
|
||||
queried replicas
|
||||
- 'majority' - query all replicas, but return values present in the
|
||||
majority of replicas
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self._embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
query_embedding, k, fetch_k, lambda_mult, **kwargs
|
||||
query_embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@sync_call_fallback
|
||||
@ -727,6 +803,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -743,12 +823,42 @@ class Qdrant(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
score_threshold:
|
||||
Define a minimal score threshold for the result.
|
||||
If defined, less similar results will not be returned.
|
||||
Score of the returned result might be higher or smaller than the
|
||||
threshold depending on the Distance function used.
|
||||
E.g. for cosine similarity only higher scores will be returned.
|
||||
consistency:
|
||||
Read consistency of the search. Defines how many replicas should be
|
||||
queried before returning the result.
|
||||
Values:
|
||||
- int - number of replicas to query, values should present in all
|
||||
queried replicas
|
||||
- 'majority' - query all replicas, but return values present in the
|
||||
majority of replicas
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self._embed_query(query)
|
||||
return await self.amax_marginal_relevance_search_by_vector(
|
||||
query_embedding, k, fetch_k, lambda_mult, **kwargs
|
||||
query_embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
@ -757,6 +867,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -772,11 +886,40 @@ class Qdrant(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
score_threshold:
|
||||
Define a minimal score threshold for the result.
|
||||
If defined, less similar results will not be returned.
|
||||
Score of the returned result might be higher or smaller than the
|
||||
threshold depending on the Distance function used.
|
||||
E.g. for cosine similarity only higher scores will be returned.
|
||||
consistency:
|
||||
Read consistency of the search. Defines how many replicas should be
|
||||
queried before returning the result.
|
||||
Values:
|
||||
- int - number of replicas to query, values should present in all
|
||||
queried replicas
|
||||
- 'majority' - query all replicas, but return values present in the
|
||||
majority of replicas
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
results = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
@ -787,6 +930,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -801,12 +948,42 @@ class Qdrant(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
score_threshold:
|
||||
Define a minimal score threshold for the result.
|
||||
If defined, less similar results will not be returned.
|
||||
Score of the returned result might be higher or smaller than the
|
||||
threshold depending on the Distance function used.
|
||||
E.g. for cosine similarity only higher scores will be returned.
|
||||
consistency:
|
||||
Read consistency of the search. Defines how many replicas should be
|
||||
queried before returning the result.
|
||||
Values:
|
||||
- int - number of replicas to query, values should present in all
|
||||
queried replicas
|
||||
- 'majority' - query all replicas, but return values present in the
|
||||
majority of replicas
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
"""
|
||||
results = await self.amax_marginal_relevance_search_with_score_by_vector(
|
||||
embedding, k, fetch_k, lambda_mult, **kwargs
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
return list(map(itemgetter(0), results))
|
||||
|
||||
@ -816,6 +993,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -830,6 +1011,27 @@ class Qdrant(VectorStore):
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
score_threshold:
|
||||
Define a minimal score threshold for the result.
|
||||
If defined, less similar results will not be returned.
|
||||
Score of the returned result might be higher or smaller than the
|
||||
threshold depending on the Distance function used.
|
||||
E.g. for cosine similarity only higher scores will be returned.
|
||||
consistency:
|
||||
Read consistency of the search. Defines how many replicas should be
|
||||
queried before returning the result.
|
||||
Values:
|
||||
- int - number of replicas to query, values should present in all
|
||||
queried replicas
|
||||
- 'majority' - query all replicas, but return values present in the
|
||||
majority of replicas
|
||||
- 'quorum' - query the majority of replicas, return values present in
|
||||
all of them
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to QdrantClient.search()
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
@ -841,9 +1043,14 @@ class Qdrant(VectorStore):
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
search_params=search_params,
|
||||
limit=fetch_k,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
limit=fetch_k,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
embeddings = [
|
||||
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||
@ -871,6 +1078,10 @@ class Qdrant(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
search_params: Optional[common_types.SearchParams] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
consistency: Optional[common_types.ReadConsistency] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -889,18 +1100,17 @@ class Qdrant(VectorStore):
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
"""
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions.conversion import GrpcToRest
|
||||
|
||||
response = await self.client.async_grpc_points.Search(
|
||||
grpc.SearchPoints(
|
||||
collection_name=self.collection_name,
|
||||
vector_name=self.vector_name,
|
||||
vector=embedding,
|
||||
with_payload=grpc.WithPayloadSelector(enable=True),
|
||||
with_vectors=grpc.WithVectorsSelector(enable=True),
|
||||
limit=fetch_k,
|
||||
)
|
||||
response = await self._asearch_with_score_by_vector(
|
||||
embedding,
|
||||
k=fetch_k,
|
||||
filter=filter,
|
||||
search_params=search_params,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
with_vectors=True,
|
||||
**kwargs,
|
||||
)
|
||||
results = [
|
||||
GrpcToRest.convert_vectors(result.vectors) for result in response.result
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from qdrant_client import models
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Qdrant
|
||||
@ -20,6 +21,17 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test end to end construction and MRR search."""
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"{metadata_payload_key}.page",
|
||||
match=models.MatchValue(
|
||||
value=2,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = Qdrant.from_texts(
|
||||
@ -40,3 +52,10 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
Document(page_content="baz", metadata={"page": 2}),
|
||||
]
|
||||
|
||||
output = docsearch.max_marginal_relevance_search(
|
||||
"foo", k=2, fetch_k=3, lambda_mult=0.0, filter=filter
|
||||
)
|
||||
assert output == [
|
||||
Document(page_content="baz", metadata={"page": 2}),
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user