From ed6a5532ac4fffa2d7910ab034ec6965c93c8cb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Fri, 21 Jul 2023 04:04:33 +0200 Subject: [PATCH] Implement async support in Qdrant local mode (#8001) I've extended the support of async API to local Qdrant mode. It is faked but allows prototyping without spinning a container. The tests are improved to test the in-memory case as well. @baskaryan @rlancemartin @eyurtsev @agola11 --- .../how_to/self_query/qdrant_self_query.ipynb | 1 - langchain/vectorstores/qdrant.py | 120 ++++++++++++++---- .../qdrant/{async => async_api}/__init__.py | 0 .../vectorstores/qdrant/async_api/fixtures.py | 13 ++ .../{async => async_api}/test_add_texts.py | 31 +++-- .../{async => async_api}/test_from_texts.py | 56 ++++---- .../test_max_marginal_relevance.py | 16 ++- .../test_similarity_search.py | 36 +++++- .../vectorstores/qdrant/{async => }/common.py | 0 .../qdrant/test_max_marginal_relevance.py | 6 +- 10 files changed, 196 insertions(+), 83 deletions(-) rename tests/integration_tests/vectorstores/qdrant/{async => async_api}/__init__.py (100%) create mode 100644 tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py rename tests/integration_tests/vectorstores/qdrant/{async => async_api}/test_add_texts.py (80%) rename tests/integration_tests/vectorstores/qdrant/{async => async_api}/test_from_texts.py (81%) rename tests/integration_tests/vectorstores/qdrant/{async => async_api}/test_max_marginal_relevance.py (75%) rename tests/integration_tests/vectorstores/qdrant/{async => async_api}/test_similarity_search.py (88%) rename tests/integration_tests/vectorstores/qdrant/{async => }/common.py (100%) diff --git a/docs/extras/modules/data_connection/retrievers/how_to/self_query/qdrant_self_query.ipynb b/docs/extras/modules/data_connection/retrievers/how_to/self_query/qdrant_self_query.ipynb index a23908135d9..70afecac7d1 100644 --- 
a/docs/extras/modules/data_connection/retrievers/how_to/self_query/qdrant_self_query.ipynb +++ b/docs/extras/modules/data_connection/retrievers/how_to/self_query/qdrant_self_query.ipynb @@ -114,7 +114,6 @@ " \"rating\": 9.9,\n", " \"director\": \"Andrei Tarkovsky\",\n", " \"genre\": \"science fiction\",\n", - " \"rating\": 9.9,\n", " },\n", " ),\n", "]\n", diff --git a/langchain/vectorstores/qdrant.py b/langchain/vectorstores/qdrant.py index e0cba41ee78..9ce999578ba 100644 --- a/langchain/vectorstores/qdrant.py +++ b/langchain/vectorstores/qdrant.py @@ -1,6 +1,8 @@ """Wrapper around Qdrant vector database.""" from __future__ import annotations +import asyncio +import functools import uuid import warnings from itertools import islice @@ -40,6 +42,30 @@ class QdrantException(Exception): """Base class for all the Qdrant related exceptions""" +def sync_call_fallback(method: Callable) -> Callable: + """ + Decorator to call the synchronous method of the class if the async method is not + implemented. This decorator may only be used on methods that are defined + as async in the class. + """ + + @functools.wraps(method) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + try: + return await method(self, *args, **kwargs) + except NotImplementedError: + # If the async method is not implemented, call the synchronous method + # by removing the first letter from the method name. For example, + # if the async method is called ``aadd_texts``, the synchronous method + # will be called ``add_texts``. + sync_method = functools.partial( + getattr(self, method.__name__[1:]), *args, **kwargs + ) + return await asyncio.get_event_loop().run_in_executor(None, sync_method) + + return wrapper + + class Qdrant(VectorStore): """Wrapper around Qdrant vector database. 
@@ -155,6 +181,7 @@ class Qdrant(VectorStore): return added_ids + @sync_call_fallback async def aadd_texts( self, texts: Iterable[str], @@ -250,6 +277,7 @@ class Qdrant(VectorStore): ) return list(map(itemgetter(0), results)) + @sync_call_fallback async def asimilarity_search( self, query: str, @@ -322,6 +350,7 @@ class Qdrant(VectorStore): **kwargs, ) + @sync_call_fallback async def asimilarity_search_with_score( self, query: str, @@ -431,6 +460,7 @@ class Qdrant(VectorStore): ) return list(map(itemgetter(0), results)) + @sync_call_fallback async def asimilarity_search_by_vector( self, embedding: List[float], @@ -567,6 +597,7 @@ class Qdrant(VectorStore): for result in results ] + @sync_call_fallback async def asimilarity_search_with_score_by_vector( self, embedding: List[float], @@ -685,6 +716,7 @@ class Qdrant(VectorStore): query_embedding, k, fetch_k, lambda_mult, **kwargs ) + @sync_call_fallback async def amax_marginal_relevance_search( self, query: str, @@ -739,33 +771,12 @@ class Qdrant(VectorStore): Returns: List of Documents selected by maximal marginal relevance. 
""" - query_vector = embedding - if self.vector_name is not None: - query_vector = (self.vector_name, query_vector) # type: ignore[assignment] - - results = self.client.search( - collection_name=self.collection_name, - query_vector=query_vector, - with_payload=True, - with_vectors=True, - limit=fetch_k, + results = self.max_marginal_relevance_search_with_score_by_vector( + embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs ) - embeddings = [ - result.vector.get(self.vector_name) # type: ignore[index, union-attr] - if self.vector_name is not None - else result.vector - for result in results - ] - mmr_selected = maximal_marginal_relevance( - np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult - ) - return [ - self._document_from_scored_point( - results[i], self.content_payload_key, self.metadata_payload_key - ) - for i in mmr_selected - ] + return list(map(itemgetter(0), results)) + @sync_call_fallback async def amax_marginal_relevance_search_by_vector( self, embedding: List[float], @@ -795,6 +806,61 @@ class Qdrant(VectorStore): ) return list(map(itemgetter(0), results)) + def max_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance. + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + Args: + embedding: Embedding vector to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance and distance for + each. 
+ """ + query_vector = embedding + if self.vector_name is not None: + query_vector = (self.vector_name, query_vector) # type: ignore[assignment] + + results = self.client.search( + collection_name=self.collection_name, + query_vector=query_vector, + with_payload=True, + with_vectors=True, + limit=fetch_k, + ) + embeddings = [ + result.vector.get(self.vector_name) # type: ignore[index, union-attr] + if self.vector_name is not None + else result.vector + for result in results + ] + mmr_selected = maximal_marginal_relevance( + np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult + ) + return [ + ( + self._document_from_scored_point( + results[i], self.content_payload_key, self.metadata_payload_key + ), + results[i].score, + ) + for i in mmr_selected + ] + + @sync_call_fallback async def amax_marginal_relevance_search_with_score_by_vector( self, embedding: List[float], @@ -1038,7 +1104,6 @@ class Qdrant(VectorStore): content_payload_key, metadata_payload_key, vector_name, - batch_size, shard_number, replication_factor, write_consistency_factor, @@ -1055,6 +1120,7 @@ class Qdrant(VectorStore): return qdrant @classmethod + @sync_call_fallback async def afrom_texts( cls: Type[Qdrant], texts: List[str], @@ -1214,7 +1280,6 @@ class Qdrant(VectorStore): content_payload_key, metadata_payload_key, vector_name, - batch_size, shard_number, replication_factor, write_consistency_factor, @@ -1253,7 +1318,6 @@ class Qdrant(VectorStore): content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, - batch_size: int = 64, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None, diff --git a/tests/integration_tests/vectorstores/qdrant/async/__init__.py b/tests/integration_tests/vectorstores/qdrant/async_api/__init__.py similarity index 100% rename from tests/integration_tests/vectorstores/qdrant/async/__init__.py rename to 
tests/integration_tests/vectorstores/qdrant/async_api/__init__.py diff --git a/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py b/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py new file mode 100644 index 00000000000..2728f4f683a --- /dev/null +++ b/tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py @@ -0,0 +1,13 @@ +import logging +from typing import List + +from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running + +logger = logging.getLogger(__name__) + + +def qdrant_locations() -> List[str]: + if qdrant_is_not_running(): + logger.warning("Running Qdrant async tests in memory mode only.") + return [":memory:"] + return ["http://localhost:6333", ":memory:"] diff --git a/tests/integration_tests/vectorstores/qdrant/async/test_add_texts.py b/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py similarity index 80% rename from tests/integration_tests/vectorstores/qdrant/async/test_add_texts.py rename to tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py index ddc7ccef857..1de598d4757 100644 --- a/tests/integration_tests/vectorstores/qdrant/async/test_add_texts.py +++ b/tests/integration_tests/vectorstores/qdrant/async_api/test_add_texts.py @@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) - -from .common import qdrant_is_not_running - -# Skipping all the tests in the module if Qdrant is not running on localhost. 
-pytestmark = pytest.mark.skipif( - qdrant_is_not_running(), reason="Qdrant server is not running" +from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa + qdrant_locations, ) @pytest.mark.asyncio @pytest.mark.parametrize("batch_size", [1, 64]) -async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None: +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) +async def test_qdrant_aadd_texts_returns_all_ids( + batch_size: int, qdrant_location: str +) -> None: """Test end to end Qdrant.aadd_texts returns unique ids.""" docsearch: Qdrant = Qdrant.from_texts( ["foobar"], ConsistentFakeEmbeddings(), batch_size=batch_size, + location=qdrant_location, ) ids = await docsearch.aadd_texts(["foo", "bar", "baz"]) @@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None: @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_aadd_texts_stores_duplicated_texts( - vector_name: Optional[str], + vector_name: Optional[str], qdrant_location: str ) -> None: """Test end to end Qdrant.aadd_texts stores duplicated texts separately.""" from qdrant_client import QdrantClient from qdrant_client.http import models as rest - client = QdrantClient() + client = QdrantClient(location=qdrant_location) collection_name = "test" vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE) if vector_name is not None: @@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts( @pytest.mark.asyncio @pytest.mark.parametrize("batch_size", [1, 64]) -async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None: +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) +async def test_qdrant_aadd_texts_stores_ids( + batch_size: int, qdrant_location: str +) -> None: """Test end to end Qdrant.aadd_texts stores provided ids.""" from qdrant_client import 
QdrantClient @@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None: "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", ] - client = QdrantClient() + client = QdrantClient(location=qdrant_location) collection_name = "test" client.recreate_collection( collection_name, @@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None: @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", ["custom-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors( - vector_name: str, + vector_name: str, qdrant_location: str ) -> None: """Test end to end Qdrant.aadd_texts stores named vectors if name is provided.""" from qdrant_client import QdrantClient collection_name = "test" - client = QdrantClient() + client = QdrantClient(location=qdrant_location) client.recreate_collection( collection_name, vectors_config={ diff --git a/tests/integration_tests/vectorstores/qdrant/async/test_from_texts.py b/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py similarity index 81% rename from tests/integration_tests/vectorstores/qdrant/async/test_from_texts.py rename to tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py index 4aac43e69fa..0240c691ce4 100644 --- a/tests/integration_tests/vectorstores/qdrant/async/test_from_texts.py +++ b/tests/integration_tests/vectorstores/qdrant/async_api/test_from_texts.py @@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) - -from .common import qdrant_is_not_running - -# Skipping all the tests in the module if Qdrant is not running on localhost. 
-pytestmark = pytest.mark.skipif( - qdrant_is_not_running(), reason="Qdrant server is not running" +from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( + qdrant_locations, ) +from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running @pytest.mark.asyncio -async def test_qdrant_from_texts_stores_duplicated_texts() -> None: +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) +async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None: """Test end to end Qdrant.afrom_texts stores duplicated texts separately.""" - from qdrant_client import QdrantClient - collection_name = uuid.uuid4().hex - await Qdrant.afrom_texts( + vec_store = await Qdrant.afrom_texts( ["abc", "abc"], ConsistentFakeEmbeddings(), collection_name=collection_name, + location=qdrant_location, ) - client = QdrantClient() + client = vec_store.client assert 2 == client.count(collection_name).count @pytest.mark.asyncio @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_from_texts_stores_ids( - batch_size: int, vector_name: Optional[str] + batch_size: int, vector_name: Optional[str], qdrant_location: str ) -> None: """Test end to end Qdrant.afrom_texts stores provided ids.""" - from qdrant_client import QdrantClient - collection_name = uuid.uuid4().hex ids = [ "fa38d572-4c31-4579-aedc-1960d79df6df", "cdc1aa36-d6ab-4fb2-8a94-56674fd27484", ] - await Qdrant.afrom_texts( + vec_store = await Qdrant.afrom_texts( ["abc", "def"], ConsistentFakeEmbeddings(), ids=ids, collection_name=collection_name, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) - client = QdrantClient() + client = vec_store.client assert 2 == client.count(collection_name).count stored_ids = [point.id for point in client.scroll(collection_name)[0]] assert set(ids) == set(stored_ids) @@ 
-66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", ["custom-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_from_texts_stores_embeddings_as_named_vectors( vector_name: str, + qdrant_location: str, ) -> None: """Test end to end Qdrant.afrom_texts stores named vectors if name is provided.""" - from qdrant_client import QdrantClient - collection_name = uuid.uuid4().hex - await Qdrant.afrom_texts( + vec_store = await Qdrant.afrom_texts( ["lorem", "ipsum", "dolor", "sit", "amet"], ConsistentFakeEmbeddings(), collection_name=collection_name, vector_name=vector_name, + location=qdrant_location, ) - client = QdrantClient() + client = vec_store.client assert 5 == client.count(collection_name).count assert all( vector_name in point.vector # type: ignore[operator] @@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") async def test_qdrant_from_texts_reuses_same_collection( vector_name: Optional[str], ) -> None: """Test if Qdrant.afrom_texts reuses the same collection""" - from qdrant_client import QdrantClient - collection_name = uuid.uuid4().hex embeddings = ConsistentFakeEmbeddings() @@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection( vector_name=vector_name, ) - await Qdrant.afrom_texts( + vec_store = await Qdrant.afrom_texts( ["foo", "bar"], embeddings, collection_name=collection_name, vector_name=vector_name, ) - client = QdrantClient() + client = vec_store.client assert 7 == client.count(collection_name).count @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") async def 
test_qdrant_from_texts_raises_error_on_different_dimensionality( vector_name: Optional[str], ) -> None: @@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality( ("my-first-vector", "my-second_vector"), ], ) +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") async def test_qdrant_from_texts_raises_error_on_different_vector_name( first_vector_name: Optional[str], second_vector_name: Optional[str], @@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name( @pytest.mark.asyncio +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") async def test_qdrant_from_texts_raises_error_on_different_distance() -> None: """Test if Qdrant.afrom_texts raises an exception if distance does not match""" collection_name = uuid.uuid4().hex @@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None: @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "custom-vector"]) +@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running") async def test_qdrant_from_texts_recreates_collection_on_force_recreate( vector_name: Optional[str], ) -> None: @@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate( @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_from_texts_stores_metadatas( - batch_size: int, content_payload_key: str, metadata_payload_key: str + batch_size: int, + content_payload_key: str, + metadata_payload_key: str, + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas( 
content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, batch_size=batch_size, + location=qdrant_location, ) output = await docsearch.asimilarity_search("foo", k=1) assert output == [Document(page_content="foo", metadata={"page": 0})] diff --git a/tests/integration_tests/vectorstores/qdrant/async/test_max_marginal_relevance.py b/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py similarity index 75% rename from tests/integration_tests/vectorstores/qdrant/async/test_max_marginal_relevance.py rename to tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py index b40f65e126a..2784b4c424b 100644 --- a/tests/integration_tests/vectorstores/qdrant/async/test_max_marginal_relevance.py +++ b/tests/integration_tests/vectorstores/qdrant/async_api/test_max_marginal_relevance.py @@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) - -from .common import qdrant_is_not_running - -# Skipping all the tests in the module if Qdrant is not running on localhost. 
-pytestmark = pytest.mark.skipif( - qdrant_is_not_running(), reason="Qdrant server is not running" +from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( + qdrant_locations, ) @@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_max_marginal_relevance_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and MRR search.""" texts = ["foo", "bar", "baz"] @@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search( metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, + distance_func="EUCLID", # Euclid distance used to avoid normalization + ) + output = await docsearch.amax_marginal_relevance_search( + "foo", k=2, fetch_k=3, lambda_mult=0.0 ) - output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3) assert output == [ Document(page_content="foo", metadata={"page": 0}), Document(page_content="baz", metadata={"page": 2}), diff --git a/tests/integration_tests/vectorstores/qdrant/async/test_similarity_search.py b/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py similarity index 88% rename from tests/integration_tests/vectorstores/qdrant/async/test_similarity_search.py rename to tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py index 7833d5ead9b..c6158bbb50d 100644 --- a/tests/integration_tests/vectorstores/qdrant/async/test_similarity_search.py +++ b/tests/integration_tests/vectorstores/qdrant/async_api/test_similarity_search.py @@ -9,12 +9,8 @@ from 
langchain.vectorstores import Qdrant from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, ) - -from .common import qdrant_is_not_running - -# Skipping all the tests in the module if Qdrant is not running on localhost. -pytestmark = pytest.mark.skipif( - qdrant_is_not_running(), reason="Qdrant server is not running" +from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( + qdrant_locations, ) @@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search( batch_size: int, content_payload_key: str, metadata_payload_key: str, vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -38,6 +36,7 @@ async def test_qdrant_similarity_search( metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) output = await docsearch.asimilarity_search("foo", k=1) assert output == [Document(page_content="foo")] @@ -48,11 +47,13 @@ async def test_qdrant_similarity_search( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector( 
metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) embeddings = ConsistentFakeEmbeddings().embed_query("foo") output = await docsearch.asimilarity_search_by_vector(embeddings, k=1) @@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_score_by_vector( batch_size: int, content_payload_key: str, metadata_payload_key: str, vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector( metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) embeddings = ConsistentFakeEmbeddings().embed_query("foo") output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1) @@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector( @pytest.mark.asyncio @pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_filters( - batch_size: int, vector_name: Optional[str] + batch_size: int, vector_name: Optional[str], qdrant_location: str ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters( metadatas=metadatas, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) output = await docsearch.asimilarity_search( @@ -131,8 +138,10 @@ async def 
test_qdrant_similarity_search_filters( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_relevance_score_no_threshold( vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold( ConsistentFakeEmbeddings(), metadatas=metadatas, vector_name=vector_name, + location=qdrant_location, ) output = await docsearch.asimilarity_search_with_relevance_scores( "foo", k=3, score_threshold=None @@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_relevance_score_with_threshold( vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold( ConsistentFakeEmbeddings(), metadatas=metadatas, vector_name=vector_name, + location=qdrant_location, ) score_threshold = 0.98 @@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_similarity_search_with_relevance_score_with_threshold_and_filter( vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter( ConsistentFakeEmbeddings(), metadatas=metadatas, 
vector_name=vector_name, + location=qdrant_location, ) score_threshold = 0.99 # for almost exact match # test negative filter condition @@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter( @pytest.mark.asyncio @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_filters_with_qdrant_filters( vector_name: Optional[str], + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters( ConsistentFakeEmbeddings(), metadatas=metadatas, vector_name=vector_name, + location=qdrant_location, ) qdrant_filter = rest.Filter( @@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) +@pytest.mark.parametrize("qdrant_location", qdrant_locations()) async def test_qdrant_similarity_search_with_relevance_scores( batch_size: int, content_payload_key: str, metadata_payload_key: str, vector_name: str, + qdrant_location: str, ) -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores( metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + location=qdrant_location, ) output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) diff --git a/tests/integration_tests/vectorstores/qdrant/async/common.py b/tests/integration_tests/vectorstores/qdrant/common.py similarity index 100% rename from tests/integration_tests/vectorstores/qdrant/async/common.py rename to 
tests/integration_tests/vectorstores/qdrant/common.py diff --git a/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py b/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py index 5008dc1c7fc..5a383b36cc4 100644 --- a/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py +++ b/tests/integration_tests/vectorstores/qdrant/test_max_marginal_relevance.py @@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import ( @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"]) @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"]) @pytest.mark.parametrize("vector_name", [None, "my-vector"]) -@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server") def test_qdrant_max_marginal_relevance_search( batch_size: int, content_payload_key: str, @@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search( metadata_payload_key=metadata_payload_key, batch_size=batch_size, vector_name=vector_name, + distance_func="EUCLID", # Euclid distance used to avoid normalization + ) + output = docsearch.max_marginal_relevance_search( + "foo", k=2, fetch_k=3, lambda_mult=0.0 ) - output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) assert output == [ Document(page_content="foo", metadata={"page": 0}), Document(page_content="baz", metadata={"page": 2}),