Implement async support in Qdrant local mode (#8001)

I've extended the support of the async API to the local Qdrant mode. It is faked,
but it allows prototyping without spinning up a container. The tests have been
improved to cover the in-memory case as well.

@baskaryan @rlancemartin @eyurtsev @agola11
This commit is contained in:
Kacper Łukawski 2023-07-21 04:04:33 +02:00 committed by GitHub
parent 7717c24fc4
commit ed6a5532ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 196 additions and 83 deletions

View File

@ -114,7 +114,6 @@
" \"rating\": 9.9,\n", " \"rating\": 9.9,\n",
" \"director\": \"Andrei Tarkovsky\",\n", " \"director\": \"Andrei Tarkovsky\",\n",
" \"genre\": \"science fiction\",\n", " \"genre\": \"science fiction\",\n",
" \"rating\": 9.9,\n",
" },\n", " },\n",
" ),\n", " ),\n",
"]\n", "]\n",

View File

@ -1,6 +1,8 @@
"""Wrapper around Qdrant vector database.""" """Wrapper around Qdrant vector database."""
from __future__ import annotations from __future__ import annotations
import asyncio
import functools
import uuid import uuid
import warnings import warnings
from itertools import islice from itertools import islice
@ -40,6 +42,30 @@ class QdrantException(Exception):
"""Base class for all the Qdrant related exceptions""" """Base class for all the Qdrant related exceptions"""
def sync_call_fallback(method: Callable) -> Callable:
    """
    Decorator to call the synchronous method of the class if the async method is not
    implemented. This decorator might be only used for the methods that are defined
    as async in the class.
    """

    @functools.wraps(method)
    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            return await method(self, *args, **kwargs)
        except NotImplementedError:
            # If the async method is not implemented, call the synchronous method
            # by removing the first letter from the method name. For example,
            # if the async method is called ``aadd_texts``, the synchronous method
            # will be called ``add_texts``.
            sync_method = functools.partial(
                getattr(self, method.__name__[1:]), *args, **kwargs
            )
            # Run the blocking sync method in the default executor so the event
            # loop stays responsive. get_running_loop() is the non-deprecated way
            # to obtain the loop from inside a coroutine.
            return await asyncio.get_running_loop().run_in_executor(None, sync_method)

    return wrapper
class Qdrant(VectorStore): class Qdrant(VectorStore):
"""Wrapper around Qdrant vector database. """Wrapper around Qdrant vector database.
@ -155,6 +181,7 @@ class Qdrant(VectorStore):
return added_ids return added_ids
@sync_call_fallback
async def aadd_texts( async def aadd_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -250,6 +277,7 @@ class Qdrant(VectorStore):
) )
return list(map(itemgetter(0), results)) return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search( async def asimilarity_search(
self, self,
query: str, query: str,
@ -322,6 +350,7 @@ class Qdrant(VectorStore):
**kwargs, **kwargs,
) )
@sync_call_fallback
async def asimilarity_search_with_score( async def asimilarity_search_with_score(
self, self,
query: str, query: str,
@ -431,6 +460,7 @@ class Qdrant(VectorStore):
) )
return list(map(itemgetter(0), results)) return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search_by_vector( async def asimilarity_search_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -567,6 +597,7 @@ class Qdrant(VectorStore):
for result in results for result in results
] ]
@sync_call_fallback
async def asimilarity_search_with_score_by_vector( async def asimilarity_search_with_score_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -685,6 +716,7 @@ class Qdrant(VectorStore):
query_embedding, k, fetch_k, lambda_mult, **kwargs query_embedding, k, fetch_k, lambda_mult, **kwargs
) )
@sync_call_fallback
async def amax_marginal_relevance_search( async def amax_marginal_relevance_search(
self, self,
query: str, query: str,
@ -739,33 +771,12 @@ class Qdrant(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
query_vector = embedding results = self.max_marginal_relevance_search_with_score_by_vector(
if self.vector_name is not None: embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
) )
embeddings = [ return list(map(itemgetter(0), results))
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
)
for i in mmr_selected
]
@sync_call_fallback
async def amax_marginal_relevance_search_by_vector( async def amax_marginal_relevance_search_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -795,6 +806,61 @@ class Qdrant(VectorStore):
) )
return list(map(itemgetter(0), results)) return list(map(itemgetter(0), results))
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        **kwargs: Additional keyword arguments (currently unused by this
            implementation).

    Returns:
        List of Documents selected by maximal marginal relevance and distance for
        each.
    """
    # With named vectors, Qdrant expects the query as a (name, vector) pair.
    query_vector = embedding
    if self.vector_name is not None:
        query_vector = (self.vector_name, query_vector)  # type: ignore[assignment]

    # Fetch a larger candidate pool (fetch_k) with vectors included so the
    # MMR re-ranking below can be performed client-side.
    results = self.client.search(
        collection_name=self.collection_name,
        query_vector=query_vector,
        with_payload=True,
        with_vectors=True,
        limit=fetch_k,
    )
    # Pull the raw vectors out of the results; named vectors come back as a
    # mapping keyed by vector name.
    embeddings = [
        result.vector.get(self.vector_name)  # type: ignore[index, union-attr]
        if self.vector_name is not None
        else result.vector
        for result in results
    ]
    mmr_selected = maximal_marginal_relevance(
        np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
    )
    # Pair each selected document with the score Qdrant reported for it.
    return [
        (
            self._document_from_scored_point(
                results[i], self.content_payload_key, self.metadata_payload_key
            ),
            results[i].score,
        )
        for i in mmr_selected
    ]
@sync_call_fallback
async def amax_marginal_relevance_search_with_score_by_vector( async def amax_marginal_relevance_search_with_score_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -1038,7 +1104,6 @@ class Qdrant(VectorStore):
content_payload_key, content_payload_key,
metadata_payload_key, metadata_payload_key,
vector_name, vector_name,
batch_size,
shard_number, shard_number,
replication_factor, replication_factor,
write_consistency_factor, write_consistency_factor,
@ -1055,6 +1120,7 @@ class Qdrant(VectorStore):
return qdrant return qdrant
@classmethod @classmethod
@sync_call_fallback
async def afrom_texts( async def afrom_texts(
cls: Type[Qdrant], cls: Type[Qdrant],
texts: List[str], texts: List[str],
@ -1214,7 +1280,6 @@ class Qdrant(VectorStore):
content_payload_key, content_payload_key,
metadata_payload_key, metadata_payload_key,
vector_name, vector_name,
batch_size,
shard_number, shard_number,
replication_factor, replication_factor,
write_consistency_factor, write_consistency_factor,
@ -1253,7 +1318,6 @@ class Qdrant(VectorStore):
content_payload_key: str = CONTENT_KEY, content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME, vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None,

View File

@ -0,0 +1,13 @@
import logging
from typing import List
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
logger = logging.getLogger(__name__)
def qdrant_locations() -> List[str]:
    """Return the Qdrant locations the async tests should run against."""
    locations = [":memory:"]
    if qdrant_is_not_running():
        logger.warning("Running Qdrant async tests in memory mode only.")
    else:
        # A local server is reachable: test against it first, then in-memory.
        locations.insert(0, "http://localhost:6333")
    return locations

View File

@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
) )
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa
from .common import qdrant_is_not_running qdrant_locations,
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None: @pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_returns_all_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts returns unique ids.""" """Test end to end Qdrant.aadd_texts returns unique ids."""
docsearch: Qdrant = Qdrant.from_texts( docsearch: Qdrant = Qdrant.from_texts(
["foobar"], ["foobar"],
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
batch_size=batch_size, batch_size=batch_size,
location=qdrant_location,
) )
ids = await docsearch.aadd_texts(["foo", "bar", "baz"]) ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_duplicated_texts( async def test_qdrant_aadd_texts_stores_duplicated_texts(
vector_name: Optional[str], vector_name: Optional[str], qdrant_location: str
) -> None: ) -> None:
"""Test end to end Qdrant.aadd_texts stores duplicated texts separately.""" """Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
client = QdrantClient() client = QdrantClient(location=qdrant_location)
collection_name = "test" collection_name = "test"
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None: if vector_name is not None:
@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None: @pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores provided ids.""" """Test end to end Qdrant.aadd_texts stores provided ids."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484", "cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
] ]
client = QdrantClient() client = QdrantClient(location=qdrant_location)
collection_name = "test" collection_name = "test"
client.recreate_collection( client.recreate_collection(
collection_name, collection_name,
@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"]) @pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors( async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
vector_name: str, vector_name: str, qdrant_location: str
) -> None: ) -> None:
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided.""" """Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
collection_name = "test" collection_name = "test"
client = QdrantClient() client = QdrantClient(location=qdrant_location)
client.recreate_collection( client.recreate_collection(
collection_name, collection_name,
vectors_config={ vectors_config={

View File

@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
) )
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
from .common import qdrant_is_not_running qdrant_locations,
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
) )
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_qdrant_from_texts_stores_duplicated_texts() -> None: @pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None:
"""Test end to end Qdrant.afrom_texts stores duplicated texts separately.""" """Test end to end Qdrant.afrom_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts( vec_store = await Qdrant.afrom_texts(
["abc", "abc"], ["abc", "abc"],
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
collection_name=collection_name, collection_name=collection_name,
location=qdrant_location,
) )
client = QdrantClient() client = vec_store.client
assert 2 == client.count(collection_name).count assert 2 == client.count(collection_name).count
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_ids( async def test_qdrant_from_texts_stores_ids(
batch_size: int, vector_name: Optional[str] batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None: ) -> None:
"""Test end to end Qdrant.afrom_texts stores provided ids.""" """Test end to end Qdrant.afrom_texts stores provided ids."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
ids = [ ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df", "fa38d572-4c31-4579-aedc-1960d79df6df",
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484", "cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
] ]
await Qdrant.afrom_texts( vec_store = await Qdrant.afrom_texts(
["abc", "def"], ["abc", "def"],
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
ids=ids, ids=ids,
collection_name=collection_name, collection_name=collection_name,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
client = QdrantClient() client = vec_store.client
assert 2 == client.count(collection_name).count assert 2 == client.count(collection_name).count
stored_ids = [point.id for point in client.scroll(collection_name)[0]] stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids) assert set(ids) == set(stored_ids)
@ -66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"]) @pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_embeddings_as_named_vectors( async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
vector_name: str, vector_name: str,
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end Qdrant.afrom_texts stores named vectors if name is provided.""" """Test end to end Qdrant.afrom_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts( vec_store = await Qdrant.afrom_texts(
["lorem", "ipsum", "dolor", "sit", "amet"], ["lorem", "ipsum", "dolor", "sit", "amet"],
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
collection_name=collection_name, collection_name=collection_name,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
client = QdrantClient() client = vec_store.client
assert 5 == client.count(collection_name).count assert 5 == client.count(collection_name).count
assert all( assert all(
vector_name in point.vector # type: ignore[operator] vector_name in point.vector # type: ignore[operator]
@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_reuses_same_collection( async def test_qdrant_from_texts_reuses_same_collection(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
"""Test if Qdrant.afrom_texts reuses the same collection""" """Test if Qdrant.afrom_texts reuses the same collection"""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings() embeddings = ConsistentFakeEmbeddings()
@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection(
vector_name=vector_name, vector_name=vector_name,
) )
await Qdrant.afrom_texts( vec_store = await Qdrant.afrom_texts(
["foo", "bar"], ["foo", "bar"],
embeddings, embeddings,
collection_name=collection_name, collection_name=collection_name,
vector_name=vector_name, vector_name=vector_name,
) )
client = QdrantClient() client = vec_store.client
assert 7 == client.count(collection_name).count assert 7 == client.count(collection_name).count
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_dimensionality( async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
("my-first-vector", "my-second_vector"), ("my-first-vector", "my-second_vector"),
], ],
) )
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_vector_name( async def test_qdrant_from_texts_raises_error_on_different_vector_name(
first_vector_name: Optional[str], first_vector_name: Optional[str],
second_vector_name: Optional[str], second_vector_name: Optional[str],
@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_distance() -> None: async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
"""Test if Qdrant.afrom_texts raises an exception if distance does not match""" """Test if Qdrant.afrom_texts raises an exception if distance does not match"""
collection_name = uuid.uuid4().hex collection_name = uuid.uuid4().hex
@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"]) @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_recreates_collection_on_force_recreate( async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
vector_name: Optional[str], vector_name: Optional[str],
) -> None: ) -> None:
@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_metadatas( async def test_qdrant_from_texts_stores_metadatas(
batch_size: int, content_payload_key: str, metadata_payload_key: str batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas(
content_payload_key=content_payload_key, content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
location=qdrant_location,
) )
output = await docsearch.asimilarity_search("foo", k=1) output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})] assert output == [Document(page_content="foo", metadata={"page": 0})]

View File

@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
) )
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
from .common import qdrant_is_not_running qdrant_locations,
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
) )
@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_max_marginal_relevance_search( async def test_qdrant_max_marginal_relevance_search(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
metadata_payload_key: str, metadata_payload_key: str,
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and MRR search.""" """Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = await docsearch.amax_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
) )
output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [ assert output == [
Document(page_content="foo", metadata={"page": 0}), Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}), Document(page_content="baz", metadata={"page": 2}),

View File

@ -9,12 +9,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
) )
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
from .common import qdrant_is_not_running qdrant_locations,
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
) )
@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search( async def test_qdrant_similarity_search(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
metadata_payload_key: str, metadata_payload_key: str,
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -38,6 +36,7 @@ async def test_qdrant_similarity_search(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
output = await docsearch.asimilarity_search("foo", k=1) output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
@ -48,11 +47,13 @@ async def test_qdrant_similarity_search(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_by_vector( async def test_qdrant_similarity_search_by_vector(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
metadata_payload_key: str, metadata_payload_key: str,
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
embeddings = ConsistentFakeEmbeddings().embed_query("foo") embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_by_vector(embeddings, k=1) output = await docsearch.asimilarity_search_by_vector(embeddings, k=1)
@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_score_by_vector( async def test_qdrant_similarity_search_with_score_by_vector(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
metadata_payload_key: str, metadata_payload_key: str,
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
embeddings = ConsistentFakeEmbeddings().embed_query("foo") embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1) output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1)
@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters( async def test_qdrant_similarity_search_filters(
batch_size: int, vector_name: Optional[str] batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters(
metadatas=metadatas, metadatas=metadatas,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
output = await docsearch.asimilarity_search( output = await docsearch.asimilarity_search(
@ -131,8 +138,10 @@ async def test_qdrant_similarity_search_filters(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_no_threshold( async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
output = await docsearch.asimilarity_search_with_relevance_scores( output = await docsearch.asimilarity_search_with_relevance_scores(
"foo", k=3, score_threshold=None "foo", k=3, score_threshold=None
@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_with_threshold( async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
score_threshold = 0.98 score_threshold = 0.98
@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_similarity_search_with_relevance_score_with_threshold_and_filter( async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
score_threshold = 0.99 # for almost exact match score_threshold = 0.99 # for almost exact match
# test negative filter condition # test negative filter condition
@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters_with_qdrant_filters( async def test_qdrant_similarity_search_filters_with_qdrant_filters(
vector_name: Optional[str], vector_name: Optional[str],
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings(),
metadatas=metadatas, metadatas=metadatas,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
qdrant_filter = rest.Filter( qdrant_filter = rest.Filter(
@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_scores( async def test_qdrant_similarity_search_with_relevance_scores(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
metadata_payload_key: str, metadata_payload_key: str,
vector_name: str, vector_name: str,
qdrant_location: str,
) -> None: ) -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
location=qdrant_location,
) )
output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)

View File

@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server")
def test_qdrant_max_marginal_relevance_search( def test_qdrant_max_marginal_relevance_search(
batch_size: int, batch_size: int,
content_payload_key: str, content_payload_key: str,
@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
batch_size=batch_size, batch_size=batch_size,
vector_name=vector_name, vector_name=vector_name,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = docsearch.max_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
) )
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [ assert output == [
Document(page_content="foo", metadata={"page": 0}), Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}), Document(page_content="baz", metadata={"page": 2}),