Implement async support in Qdrant local mode (#8001)

I've extended async API support to Qdrant's local mode. The support is faked, but it allows prototyping without spinning up a container. The tests have also been improved to cover the in-memory case.

@baskaryan @rlancemartin @eyurtsev @agola11
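For illustration only (not part of the diff below), here is a minimal sketch of the workflow this change enables: calling the async vector store API against Qdrant's local in-memory mode, with no server running. FakeEmbeddings and the collection name are placeholders; any Embeddings implementation works.

import asyncio

from langchain.embeddings import FakeEmbeddings  # placeholder; any Embeddings works
from langchain.vectorstores import Qdrant

async def main() -> None:
    # location=":memory:" uses qdrant-client's local mode, so no container is needed.
    docsearch = await Qdrant.afrom_texts(
        ["foo", "bar", "baz"],
        FakeEmbeddings(size=10),
        location=":memory:",
        collection_name="demo",
    )
    await docsearch.aadd_texts(["qux"])
    print(await docsearch.asimilarity_search("foo", k=1))

asyncio.run(main())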
Commit ed6a5532ac (parent 7717c24fc4)
Author: Kacper Łukawski
Date: 2023-07-21 04:04:33 +02:00
Committed by: GitHub
10 changed files with 196 additions and 83 deletions

File: tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py (new file)

@@ -0,0 +1,13 @@
import logging
from typing import List
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
logger = logging.getLogger(__name__)
def qdrant_locations() -> List[str]:
if qdrant_is_not_running():
logger.warning("Running Qdrant async tests in memory mode only.")
return [":memory:"]
return ["http://localhost:6333", ":memory:"]

File: async API tests for Qdrant.aadd_texts

@@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa
qdrant_locations,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_returns_all_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts returns unique ids."""
docsearch: Qdrant = Qdrant.from_texts(
["foobar"],
ConsistentFakeEmbeddings(),
batch_size=batch_size,
location=qdrant_location,
)
ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
@@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_duplicated_texts(
vector_name: Optional[str],
vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None:
@@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores provided ids."""
from qdrant_client import QdrantClient
@@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
client.recreate_collection(
collection_name,
@@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
vector_name: str,
vector_name: str, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = "test"
client = QdrantClient()
client = QdrantClient(location=qdrant_location)
client.recreate_collection(
collection_name,
vectors_config={

File: async API tests for Qdrant.afrom_texts

@@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
@pytest.mark.asyncio
async def test_qdrant_from_texts_stores_duplicated_texts() -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None:
"""Test end to end Qdrant.afrom_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["abc", "abc"],
ConsistentFakeEmbeddings(),
collection_name=collection_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 2 == client.count(collection_name).count
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_ids(
batch_size: int, vector_name: Optional[str]
batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end Qdrant.afrom_texts stores provided ids."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
ids = [
"fa38d572-4c31-4579-aedc-1960d79df6df",
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["abc", "def"],
ConsistentFakeEmbeddings(),
ids=ids,
collection_name=collection_name,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 2 == client.count(collection_name).count
stored_ids = [point.id for point in client.scroll(collection_name)[0]]
assert set(ids) == set(stored_ids)
@@ -66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
vector_name: str,
qdrant_location: str,
) -> None:
"""Test end to end Qdrant.afrom_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["lorem", "ipsum", "dolor", "sit", "amet"],
ConsistentFakeEmbeddings(),
collection_name=collection_name,
vector_name=vector_name,
location=qdrant_location,
)
client = QdrantClient()
client = vec_store.client
assert 5 == client.count(collection_name).count
assert all(
vector_name in point.vector # type: ignore[operator]
@@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_reuses_same_collection(
vector_name: Optional[str],
) -> None:
"""Test if Qdrant.afrom_texts reuses the same collection"""
from qdrant_client import QdrantClient
collection_name = uuid.uuid4().hex
embeddings = ConsistentFakeEmbeddings()
@@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection(
vector_name=vector_name,
)
await Qdrant.afrom_texts(
vec_store = await Qdrant.afrom_texts(
["foo", "bar"],
embeddings,
collection_name=collection_name,
vector_name=vector_name,
)
client = QdrantClient()
client = vec_store.client
assert 7 == client.count(collection_name).count
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
vector_name: Optional[str],
) -> None:
@@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
("my-first-vector", "my-second_vector"),
],
)
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_vector_name(
first_vector_name: Optional[str],
second_vector_name: Optional[str],
@@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
@pytest.mark.asyncio
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
"""Test if Qdrant.afrom_texts raises an exception if distance does not match"""
collection_name = uuid.uuid4().hex
@@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
vector_name: Optional[str],
) -> None:
@@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_from_texts_stores_metadatas(
batch_size: int, content_payload_key: str, metadata_payload_key: str
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas(
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
location=qdrant_location,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]

File: async API tests for max marginal relevance search

@@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
@@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_max_marginal_relevance_search(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
@@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = await docsearch.amax_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
)
output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}),

File: async API tests for similarity search

@@ -9,12 +9,8 @@ from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)
from .common import qdrant_is_not_running
# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
qdrant_locations,
)
@@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -38,6 +36,7 @@ async def test_qdrant_similarity_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
@@ -48,11 +47,13 @@ async def test_qdrant_similarity_search(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_by_vector(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_by_vector(embeddings, k=1)
@@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_score_by_vector(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
embeddings = ConsistentFakeEmbeddings().embed_query("foo")
output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1)
@@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector(
@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters(
batch_size: int, vector_name: Optional[str]
batch_size: int, vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters(
metadatas=metadatas,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search(
@@ -131,8 +138,10 @@ async def test_qdrant_similarity_search_filters(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search_with_relevance_scores(
"foo", k=3, score_threshold=None
@@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
score_threshold = 0.98
@@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
score_threshold = 0.99 # for almost exact match
# test negative filter condition
@@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_filters_with_qdrant_filters(
vector_name: Optional[str],
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
ConsistentFakeEmbeddings(),
metadatas=metadatas,
vector_name=vector_name,
location=qdrant_location,
)
qdrant_filter = rest.Filter(
@@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_similarity_search_with_relevance_scores(
batch_size: int,
content_payload_key: str,
metadata_payload_key: str,
vector_name: str,
qdrant_location: str,
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
@@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
location=qdrant_location,
)
output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)

File: synchronous tests for max marginal relevance search

@@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server")
def test_qdrant_max_marginal_relevance_search(
batch_size: int,
content_payload_key: str,
@@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search(
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
vector_name=vector_name,
distance_func="EUCLID", # Euclid distance used to avoid normalization
)
output = docsearch.max_marginal_relevance_search(
"foo", k=2, fetch_k=3, lambda_mult=0.0
)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}),