Implement async support in Qdrant local mode (#8001)

I've extended async API support to Qdrant's local mode. The async calls are faked (they fall back to the synchronous implementation), but this allows prototyping without spinning up a container. The tests are improved to cover the in-memory case as well. @baskaryan @rlancemartin @eyurtsev @agola11

Parent: 7717c24fc4
Commit: ed6a5532ac
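
For orientation, a minimal usage sketch of what this enables (not part of the diff; assumes qdrant-client is installed, and FakeEmbeddings plus the collection name are illustrative stand-ins for any embeddings and collection):

import asyncio

from langchain.embeddings.fake import FakeEmbeddings
from langchain.vectorstores import Qdrant


async def main() -> None:
    # location=":memory:" uses qdrant-client's local mode, so no container is needed.
    docsearch = await Qdrant.afrom_texts(
        ["foo", "bar", "baz"],
        FakeEmbeddings(size=10),
        location=":memory:",
        collection_name="demo",
    )
    await docsearch.aadd_texts(["qux"])
    print(await docsearch.asimilarity_search("foo", k=1))


asyncio.run(main())
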
@@ -114,7 +114,6 @@
 " \"rating\": 9.9,\n",
 " \"director\": \"Andrei Tarkovsky\",\n",
 " \"genre\": \"science fiction\",\n",
-" \"rating\": 9.9,\n",
 " },\n",
 " ),\n",
 "]\n",
@@ -1,6 +1,8 @@
 """Wrapper around Qdrant vector database."""
 from __future__ import annotations
 
+import asyncio
+import functools
 import uuid
 import warnings
 from itertools import islice
@@ -40,6 +42,30 @@ class QdrantException(Exception):
     """Base class for all the Qdrant related exceptions"""
 
 
+def sync_call_fallback(method: Callable) -> Callable:
+    """
+    Decorator to call the synchronous method of the class if the async method is not
+    implemented. This decorator should only be used for methods that are defined
+    as async in the class.
+    """
+
+    @functools.wraps(method)
+    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+        try:
+            return await method(self, *args, **kwargs)
+        except NotImplementedError:
+            # If the async method is not implemented, call the synchronous method
+            # by removing the first letter from the method name. For example,
+            # if the async method is called ``aadd_texts``, the synchronous method
+            # will be called ``add_texts``.
+            sync_method = functools.partial(
+                getattr(self, method.__name__[1:]), *args, **kwargs
+            )
+            return await asyncio.get_event_loop().run_in_executor(None, sync_method)
+
+    return wrapper
+
+
 class Qdrant(VectorStore):
     """Wrapper around Qdrant vector database.
 
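
The decorator above relies on a naming convention: when an async method named a<name> raises NotImplementedError, the call falls back to the synchronous <name> method, run in the default thread pool so the event loop is not blocked. A standalone sketch of the mechanism, assuming a langchain build containing this commit (the Example class is hypothetical, not part of the diff):

import asyncio

from langchain.vectorstores.qdrant import sync_call_fallback


class Example:
    def upper(self, text: str) -> str:
        return text.upper()

    @sync_call_fallback
    async def aupper(self, text: str) -> str:
        # No native async implementation: the decorator catches NotImplementedError,
        # strips the leading "a" from "aupper", and runs upper() in an executor.
        raise NotImplementedError


print(asyncio.run(Example().aupper("qdrant")))  # -> QDRANT
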
@@ -155,6 +181,7 @@ class Qdrant(VectorStore):
 
         return added_ids
 
+    @sync_call_fallback
     async def aadd_texts(
         self,
         texts: Iterable[str],
@@ -250,6 +277,7 @@
         )
         return list(map(itemgetter(0), results))
 
+    @sync_call_fallback
     async def asimilarity_search(
         self,
         query: str,
@@ -322,6 +350,7 @@
             **kwargs,
         )
 
+    @sync_call_fallback
     async def asimilarity_search_with_score(
         self,
         query: str,
@@ -431,6 +460,7 @@
         )
         return list(map(itemgetter(0), results))
 
+    @sync_call_fallback
     async def asimilarity_search_by_vector(
         self,
         embedding: List[float],
@@ -567,6 +597,7 @@
             for result in results
         ]
 
+    @sync_call_fallback
     async def asimilarity_search_with_score_by_vector(
         self,
         embedding: List[float],
@@ -685,6 +716,7 @@
             query_embedding, k, fetch_k, lambda_mult, **kwargs
         )
 
+    @sync_call_fallback
     async def amax_marginal_relevance_search(
         self,
         query: str,
@@ -739,33 +771,12 @@
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
-        query_vector = embedding
-        if self.vector_name is not None:
-            query_vector = (self.vector_name, query_vector)  # type: ignore[assignment]
-
-        results = self.client.search(
-            collection_name=self.collection_name,
-            query_vector=query_vector,
-            with_payload=True,
-            with_vectors=True,
-            limit=fetch_k,
-        )
-        embeddings = [
-            result.vector.get(self.vector_name)  # type: ignore[index, union-attr]
-            if self.vector_name is not None
-            else result.vector
-            for result in results
-        ]
-        mmr_selected = maximal_marginal_relevance(
-            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
-        )
-        return [
-            self._document_from_scored_point(
-                results[i], self.content_payload_key, self.metadata_payload_key
-            )
-            for i in mmr_selected
-        ]
+        results = self.max_marginal_relevance_search_with_score_by_vector(
+            embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
+        )
+        return list(map(itemgetter(0), results))
 
+    @sync_call_fallback
     async def amax_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -795,6 +806,61 @@
         )
         return list(map(itemgetter(0), results))
 
+    def max_marginal_relevance_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+        Args:
+            embedding: Embedding vector to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents selected by maximal marginal relevance and distance for
+            each.
+        """
+        query_vector = embedding
+        if self.vector_name is not None:
+            query_vector = (self.vector_name, query_vector)  # type: ignore[assignment]
+
+        results = self.client.search(
+            collection_name=self.collection_name,
+            query_vector=query_vector,
+            with_payload=True,
+            with_vectors=True,
+            limit=fetch_k,
+        )
+        embeddings = [
+            result.vector.get(self.vector_name)  # type: ignore[index, union-attr]
+            if self.vector_name is not None
+            else result.vector
+            for result in results
+        ]
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+        )
+        return [
+            (
+                self._document_from_scored_point(
+                    results[i], self.content_payload_key, self.metadata_payload_key
+                ),
+                results[i].score,
+            )
+            for i in mmr_selected
+        ]
+
+    @sync_call_fallback
     async def amax_marginal_relevance_search_with_score_by_vector(
         self,
         embedding: List[float],
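
A short sketch of exercising the new MMR code path through the async API, mirroring the updated tests (not part of the diff; FakeEmbeddings and the sample texts are illustrative):

import asyncio

from langchain.embeddings.fake import FakeEmbeddings
from langchain.vectorstores import Qdrant


async def mmr_demo() -> None:
    embeddings = FakeEmbeddings(size=10)
    docsearch = await Qdrant.afrom_texts(
        ["foo", "bar", "baz"],
        embeddings,
        location=":memory:",
        distance_func="EUCLID",  # as in the tests, avoids score normalization
    )
    # lambda_mult=0.0 maximizes diversity among the fetched candidates.
    docs = await docsearch.amax_marginal_relevance_search(
        "foo", k=2, fetch_k=3, lambda_mult=0.0
    )
    # The new with-score variant also returns the distance for each document.
    docs_and_scores = docsearch.max_marginal_relevance_search_with_score_by_vector(
        embeddings.embed_query("foo"), k=2
    )
    print(docs, docs_and_scores)


asyncio.run(mmr_demo())
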
@@ -1038,7 +1104,6 @@
             content_payload_key,
             metadata_payload_key,
             vector_name,
-            batch_size,
             shard_number,
             replication_factor,
             write_consistency_factor,
@@ -1055,6 +1120,7 @@
         return qdrant
 
     @classmethod
+    @sync_call_fallback
     async def afrom_texts(
         cls: Type[Qdrant],
         texts: List[str],
@@ -1214,7 +1280,6 @@
             content_payload_key,
             metadata_payload_key,
             vector_name,
-            batch_size,
             shard_number,
             replication_factor,
             write_consistency_factor,
@@ -1253,7 +1318,6 @@
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
         vector_name: Optional[str] = VECTOR_NAME,
-        batch_size: int = 64,
         shard_number: Optional[int] = None,
         replication_factor: Optional[int] = None,
         write_consistency_factor: Optional[int] = None,
@@ -0,0 +1,13 @@
+import logging
+from typing import List
+
+from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
+
+logger = logging.getLogger(__name__)
+
+
+def qdrant_locations() -> List[str]:
+    if qdrant_is_not_running():
+        logger.warning("Running Qdrant async tests in memory mode only.")
+        return [":memory:"]
+    return ["http://localhost:6333", ":memory:"]
@@ -7,23 +7,23 @@ from langchain.vectorstores import Qdrant
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
 )
-from .common import qdrant_is_not_running
-
-# Skipping all the tests in the module if Qdrant is not running on localhost.
-pytestmark = pytest.mark.skipif(
-    qdrant_is_not_running(), reason="Qdrant server is not running"
+from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (  # noqa
+    qdrant_locations,
 )
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("batch_size", [1, 64])
-async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
+async def test_qdrant_aadd_texts_returns_all_ids(
+    batch_size: int, qdrant_location: str
+) -> None:
     """Test end to end Qdrant.aadd_texts returns unique ids."""
     docsearch: Qdrant = Qdrant.from_texts(
         ["foobar"],
         ConsistentFakeEmbeddings(),
         batch_size=batch_size,
+        location=qdrant_location,
     )
 
     ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
@@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_aadd_texts_stores_duplicated_texts(
-    vector_name: Optional[str],
+    vector_name: Optional[str], qdrant_location: str
 ) -> None:
     """Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
     from qdrant_client import QdrantClient
     from qdrant_client.http import models as rest
 
-    client = QdrantClient()
+    client = QdrantClient(location=qdrant_location)
     collection_name = "test"
     vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
     if vector_name is not None:
@@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("batch_size", [1, 64])
-async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
+async def test_qdrant_aadd_texts_stores_ids(
+    batch_size: int, qdrant_location: str
+) -> None:
     """Test end to end Qdrant.aadd_texts stores provided ids."""
     from qdrant_client import QdrantClient
 
@@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
         "cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
     ]
 
-    client = QdrantClient()
+    client = QdrantClient(location=qdrant_location)
     collection_name = "test"
     client.recreate_collection(
         collection_name,
@@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", ["custom-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
-    vector_name: str,
+    vector_name: str, qdrant_location: str
 ) -> None:
     """Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
     from qdrant_client import QdrantClient
 
     collection_name = "test"
 
-    client = QdrantClient()
+    client = QdrantClient(location=qdrant_location)
     client.recreate_collection(
         collection_name,
         vectors_config={
@@ -9,56 +9,53 @@ from langchain.vectorstores.qdrant import QdrantException
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
 )
-from .common import qdrant_is_not_running
-
-# Skipping all the tests in the module if Qdrant is not running on localhost.
-pytestmark = pytest.mark.skipif(
-    qdrant_is_not_running(), reason="Qdrant server is not running"
+from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
+    qdrant_locations,
 )
+from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running
 
 
 @pytest.mark.asyncio
-async def test_qdrant_from_texts_stores_duplicated_texts() -> None:
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
+async def test_qdrant_from_texts_stores_duplicated_texts(qdrant_location: str) -> None:
     """Test end to end Qdrant.afrom_texts stores duplicated texts separately."""
-    from qdrant_client import QdrantClient
-
     collection_name = uuid.uuid4().hex
 
-    await Qdrant.afrom_texts(
+    vec_store = await Qdrant.afrom_texts(
         ["abc", "abc"],
         ConsistentFakeEmbeddings(),
         collection_name=collection_name,
+        location=qdrant_location,
     )
 
-    client = QdrantClient()
+    client = vec_store.client
     assert 2 == client.count(collection_name).count
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_from_texts_stores_ids(
-    batch_size: int, vector_name: Optional[str]
+    batch_size: int, vector_name: Optional[str], qdrant_location: str
 ) -> None:
     """Test end to end Qdrant.afrom_texts stores provided ids."""
-    from qdrant_client import QdrantClient
-
     collection_name = uuid.uuid4().hex
     ids = [
         "fa38d572-4c31-4579-aedc-1960d79df6df",
         "cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
     ]
-    await Qdrant.afrom_texts(
+    vec_store = await Qdrant.afrom_texts(
        ["abc", "def"],
         ConsistentFakeEmbeddings(),
         ids=ids,
         collection_name=collection_name,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
 
-    client = QdrantClient()
+    client = vec_store.client
     assert 2 == client.count(collection_name).count
     stored_ids = [point.id for point in client.scroll(collection_name)[0]]
     assert set(ids) == set(stored_ids)
@@ -66,22 +63,23 @@ async def test_qdrant_from_texts_stores_ids(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", ["custom-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
     vector_name: str,
+    qdrant_location: str,
 ) -> None:
     """Test end to end Qdrant.afrom_texts stores named vectors if name is provided."""
-    from qdrant_client import QdrantClient
-
     collection_name = uuid.uuid4().hex
 
-    await Qdrant.afrom_texts(
+    vec_store = await Qdrant.afrom_texts(
         ["lorem", "ipsum", "dolor", "sit", "amet"],
         ConsistentFakeEmbeddings(),
         collection_name=collection_name,
         vector_name=vector_name,
+        location=qdrant_location,
     )
 
-    client = QdrantClient()
+    client = vec_store.client
     assert 5 == client.count(collection_name).count
     assert all(
         vector_name in point.vector  # type: ignore[operator]
@@ -91,12 +89,11 @@ async def test_qdrant_from_texts_stores_embeddings_as_named_vectors(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
+@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
 async def test_qdrant_from_texts_reuses_same_collection(
     vector_name: Optional[str],
 ) -> None:
     """Test if Qdrant.afrom_texts reuses the same collection"""
-    from qdrant_client import QdrantClient
-
     collection_name = uuid.uuid4().hex
     embeddings = ConsistentFakeEmbeddings()
 
@@ -107,19 +104,20 @@ async def test_qdrant_from_texts_reuses_same_collection(
         vector_name=vector_name,
     )
 
-    await Qdrant.afrom_texts(
+    vec_store = await Qdrant.afrom_texts(
         ["foo", "bar"],
         embeddings,
         collection_name=collection_name,
         vector_name=vector_name,
     )
 
-    client = QdrantClient()
+    client = vec_store.client
     assert 7 == client.count(collection_name).count
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
+@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
 async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
     vector_name: Optional[str],
 ) -> None:
@@ -152,6 +150,7 @@ async def test_qdrant_from_texts_raises_error_on_different_dimensionality(
         ("my-first-vector", "my-second_vector"),
     ],
 )
+@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
 async def test_qdrant_from_texts_raises_error_on_different_vector_name(
     first_vector_name: Optional[str],
     second_vector_name: Optional[str],
@@ -176,6 +175,7 @@ async def test_qdrant_from_texts_raises_error_on_different_vector_name(
 
 
 @pytest.mark.asyncio
+@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
 async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
     """Test if Qdrant.afrom_texts raises an exception if distance does not match"""
     collection_name = uuid.uuid4().hex
@@ -198,6 +198,7 @@ async def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "custom-vector"])
+@pytest.mark.skipif(qdrant_is_not_running(), reason="Qdrant is not running")
 async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
     vector_name: Optional[str],
 ) -> None:
@@ -229,8 +230,12 @@ async def test_qdrant_from_texts_recreates_collection_on_force_recreate(
 @pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_from_texts_stores_metadatas(
-    batch_size: int, content_payload_key: str, metadata_payload_key: str
+    batch_size: int,
+    content_payload_key: str,
+    metadata_payload_key: str,
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -242,6 +247,7 @@ async def test_qdrant_from_texts_stores_metadatas(
         content_payload_key=content_payload_key,
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
+        location=qdrant_location,
     )
     output = await docsearch.asimilarity_search("foo", k=1)
     assert output == [Document(page_content="foo", metadata={"page": 0})]
@@ -7,12 +7,8 @@ from langchain.vectorstores import Qdrant
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
 )
-from .common import qdrant_is_not_running
-
-# Skipping all the tests in the module if Qdrant is not running on localhost.
-pytestmark = pytest.mark.skipif(
-    qdrant_is_not_running(), reason="Qdrant server is not running"
+from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
+    qdrant_locations,
 )
 
 
@@ -21,11 +17,13 @@ pytestmark = pytest.mark.skipif(
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_max_marginal_relevance_search(
     batch_size: int,
     content_payload_key: str,
     metadata_payload_key: str,
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and MMR search."""
     texts = ["foo", "bar", "baz"]
@@ -38,8 +36,12 @@ async def test_qdrant_max_marginal_relevance_search(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
+        distance_func="EUCLID",  # Euclid distance used to avoid normalization
+    )
+    output = await docsearch.amax_marginal_relevance_search(
+        "foo", k=2, fetch_k=3, lambda_mult=0.0
     )
-    output = await docsearch.amax_marginal_relevance_search("foo", k=2, fetch_k=3)
     assert output == [
         Document(page_content="foo", metadata={"page": 0}),
         Document(page_content="baz", metadata={"page": 2}),
@@ -9,12 +9,8 @@ from langchain.vectorstores import Qdrant
 from tests.integration_tests.vectorstores.fake_embeddings import (
     ConsistentFakeEmbeddings,
 )
-from .common import qdrant_is_not_running
-
-# Skipping all the tests in the module if Qdrant is not running on localhost.
-pytestmark = pytest.mark.skipif(
-    qdrant_is_not_running(), reason="Qdrant server is not running"
+from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import (
+    qdrant_locations,
 )
 
 
@@ -23,11 +19,13 @@ pytestmark = pytest.mark.skipif(
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search(
     batch_size: int,
     content_payload_key: str,
     metadata_payload_key: str,
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -38,6 +36,7 @@ async def test_qdrant_similarity_search(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     output = await docsearch.asimilarity_search("foo", k=1)
     assert output == [Document(page_content="foo")]
@@ -48,11 +47,13 @@ async def test_qdrant_similarity_search(
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_by_vector(
     batch_size: int,
     content_payload_key: str,
     metadata_payload_key: str,
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -63,6 +64,7 @@ async def test_qdrant_similarity_search_by_vector(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     embeddings = ConsistentFakeEmbeddings().embed_query("foo")
     output = await docsearch.asimilarity_search_by_vector(embeddings, k=1)
@@ -74,11 +76,13 @@ async def test_qdrant_similarity_search_by_vector(
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_with_score_by_vector(
     batch_size: int,
     content_payload_key: str,
     metadata_payload_key: str,
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -89,6 +93,7 @@ async def test_qdrant_similarity_search_with_score_by_vector(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     embeddings = ConsistentFakeEmbeddings().embed_query("foo")
     output = await docsearch.asimilarity_search_with_score_by_vector(embeddings, k=1)
@@ -101,8 +106,9 @@ async def test_qdrant_similarity_search_with_score_by_vector(
 @pytest.mark.asyncio
 @pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_filters(
-    batch_size: int, vector_name: Optional[str]
+    batch_size: int, vector_name: Optional[str], qdrant_location: str
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -116,6 +122,7 @@ async def test_qdrant_similarity_search_filters(
         metadatas=metadatas,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
 
     output = await docsearch.asimilarity_search(
@@ -131,8 +138,10 @@ async def test_qdrant_similarity_search_filters(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -145,6 +154,7 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
         ConsistentFakeEmbeddings(),
         metadatas=metadatas,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     output = await docsearch.asimilarity_search_with_relevance_scores(
         "foo", k=3, score_threshold=None
@@ -157,8 +167,10 @@ async def test_qdrant_similarity_search_with_relevance_score_no_threshold(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -171,6 +183,7 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
         ConsistentFakeEmbeddings(),
         metadatas=metadatas,
         vector_name=vector_name,
+        location=qdrant_location,
     )
 
     score_threshold = 0.98
@@ -184,8 +197,10 @@ async def test_qdrant_similarity_search_with_relevance_score_with_threshold(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -198,6 +213,7 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
         ConsistentFakeEmbeddings(),
         metadatas=metadatas,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     score_threshold = 0.99  # for almost exact match
     # test negative filter condition
@@ -217,8 +233,10 @@ async def test_similarity_search_with_relevance_score_with_threshold_and_filter(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_filters_with_qdrant_filters(
     vector_name: Optional[str],
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -231,6 +249,7 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
         ConsistentFakeEmbeddings(),
         metadatas=metadatas,
         vector_name=vector_name,
+        location=qdrant_location,
     )
 
     qdrant_filter = rest.Filter(
@@ -263,11 +282,13 @@ async def test_qdrant_similarity_search_filters_with_qdrant_filters(
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
+@pytest.mark.parametrize("qdrant_location", qdrant_locations())
 async def test_qdrant_similarity_search_with_relevance_scores(
     batch_size: int,
     content_payload_key: str,
     metadata_payload_key: str,
     vector_name: str,
+    qdrant_location: str,
 ) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -278,6 +299,7 @@ async def test_qdrant_similarity_search_with_relevance_scores(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        location=qdrant_location,
     )
     output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
 
@@ -13,7 +13,6 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
 @pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
 @pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
 @pytest.mark.parametrize("vector_name", [None, "my-vector"])
-@pytest.mark.skip(reason="Qdrant local behaves differently from Qdrant server")
 def test_qdrant_max_marginal_relevance_search(
     batch_size: int,
     content_payload_key: str,
@@ -32,8 +31,11 @@ def test_qdrant_max_marginal_relevance_search(
         metadata_payload_key=metadata_payload_key,
         batch_size=batch_size,
         vector_name=vector_name,
+        distance_func="EUCLID",  # Euclid distance used to avoid normalization
+    )
+    output = docsearch.max_marginal_relevance_search(
+        "foo", k=2, fetch_k=3, lambda_mult=0.0
     )
-    output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
     assert output == [
         Document(page_content="foo", metadata={"page": 0}),
         Document(page_content="baz", metadata={"page": 2}),