Reuse the existing collection if configured properly in Qdrant.from_texts (#7530)

This PR changes the behavior of `Qdrant.from_texts` so that an existing
collection is reused unless recreating it is explicitly requested (see the
`force_recreate` parameter). Previously, calling `Qdrant.from_texts` or
`Qdrant.from_documents` always removed the old data, which was confusing
for many users.
This commit is contained in:
Kacper Łukawski 2023-07-11 22:24:35 +02:00 committed by GitHub
parent 6674b33cf5
commit 1f83b5f47e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 989 additions and 804 deletions

View File

@ -34,6 +34,10 @@ if TYPE_CHECKING:
MetadataFilter = Union[DictFilter, common_types.Filter] MetadataFilter = Union[DictFilter, common_types.Filter]
class QdrantException(Exception):
"""Base class for all the Qdrant related exceptions"""
class Qdrant(VectorStore): class Qdrant(VectorStore):
"""Wrapper around Qdrant vector database. """Wrapper around Qdrant vector database.
@ -552,6 +556,7 @@ class Qdrant(VectorStore):
wal_config: Optional[common_types.WalConfigDiff] = None, wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None, quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None, init_from: Optional[common_types.InitFrom] = None,
force_recreate: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts. """Construct Qdrant wrapper from a list of texts.
@ -636,6 +641,8 @@ class Qdrant(VectorStore):
Params for quantization, if None - quantization will be disabled Params for quantization, if None - quantization will be disabled
init_from: init_from:
Use data stored in another collection to initialize this collection Use data stored in another collection to initialize this collection
force_recreate:
Force recreating the collection
**kwargs: **kwargs:
Additional arguments passed directly into REST client initialization Additional arguments passed directly into REST client initialization
@ -663,7 +670,9 @@ class Qdrant(VectorStore):
"Please install it with `pip install qdrant-client`." "Please install it with `pip install qdrant-client`."
) )
from grpc import RpcError
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
# Just do a single quick embedding to get vector size # Just do a single quick embedding to get vector size
partial_embeddings = embedding.embed_documents(texts[:1]) partial_embeddings = embedding.embed_documents(texts[:1])
@ -687,62 +696,98 @@ class Qdrant(VectorStore):
**kwargs, **kwargs,
) )
vectors_config = rest.VectorParams( try:
size=vector_size, # Skip any validation in case of forced collection recreate.
distance=rest.Distance[distance_func], if force_recreate:
) raise ValueError
# If vector name was provided, we're going to use the named vectors feature # Get the vector configuration of the existing collection and vector, if it
# with just a single vector. # was specified. If the old configuration does not match the current one,
if vector_name is not None: # an exception is being thrown.
vectors_config = { # type: ignore[assignment] collection_info = client.get_collection(collection_name=collection_name)
vector_name: vectors_config, current_vector_config = collection_info.config.params.vectors
} if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(
vector_name
) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
client.recreate_collection( # Check if the vector configuration has the same dimensionality.
collection_name=collection_name, if current_vector_config.size != vector_size: # type: ignore[union-attr]
vectors_config=vectors_config, raise QdrantException(
shard_number=shard_number, f"Existing Qdrant collection is configured for vectors with "
replication_factor=replication_factor, f"{current_vector_config.size} " # type: ignore[union-attr]
write_consistency_factor=write_consistency_factor, f"dimensions. Selected embeddings are {vector_size}-dimensional. "
on_disk_payload=on_disk_payload, f"If you want to recreate the collection, set `force_recreate` "
hnsw_config=hnsw_config, f"parameter to `True`."
optimizers_config=optimizers_config, )
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
texts_iterator = iter(texts) current_distance_func = (
metadatas_iterator = iter(metadatas or []) current_vector_config.distance.name.upper() # type: ignore[union-attr]
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) )
while batch_texts := list(islice(texts_iterator, batch_size)): if current_distance_func != distance_func:
# Take the corresponding metadata and id for each text in a batch raise QdrantException(
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None f"Existing Qdrant collection is configured for "
batch_ids = list(islice(ids_iterator, batch_size)) f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
# Generate the embeddings for all the texts in a batch f"`{distance_func}` if you want to reuse it. If you want to "
batch_embeddings = embedding.embed_documents(batch_texts) f"recreate the collection, set `force_recreate` parameter to "
if vector_name is not None: f"`True`."
batch_embeddings = { # type: ignore[assignment] )
vector_name: batch_embeddings except (UnexpectedResponse, RpcError, ValueError):
} vectors_config = rest.VectorParams(
size=vector_size,
points = rest.Batch.construct( distance=rest.Distance[distance_func],
ids=batch_ids,
vectors=batch_embeddings,
payloads=cls._build_payloads(
batch_texts,
batch_metadatas,
content_payload_key,
metadata_payload_key,
),
) )
client.upsert(collection_name=collection_name, points=points) # If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
return cls( client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
qdrant = cls(
client=client, client=client,
collection_name=collection_name, collection_name=collection_name,
embeddings=embedding, embeddings=embedding,
@ -751,6 +796,8 @@ class Qdrant(VectorStore):
distance_strategy=distance_func, distance_strategy=distance_func,
vector_name=vector_name, vector_name=vector_name,
) )
qdrant.add_texts(texts, metadatas, ids, batch_size)
return qdrant
@classmethod @classmethod
def _build_payloads( def _build_payloads(

View File

@ -27,8 +27,9 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent """Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts.""" vectors for the same texts."""
def __init__(self) -> None: def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = [] self.known_texts: List[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far.""" """Return consistent embeddings for each text seen so far."""
@ -36,7 +37,9 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
for text in texts: for text in texts:
if text not in self.known_texts: if text not in self.known_texts:
self.known_texts.append(text) self.known_texts.append(text)
vector = [float(1.0)] * 9 + [float(self.known_texts.index(text))] vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector) out_vectors.append(vector)
return out_vectors return out_vectors
@ -44,8 +47,10 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Return consistent embeddings for the text, if seen before, or a constant """Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown.""" one if the text is unknown."""
if text not in self.known_texts: if text not in self.known_texts:
return [float(1.0)] * 9 + [float(0.0)] return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
return [float(1.0)] * 9 + [float(self.known_texts.index(text))] return [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
class AngularTwoDimensionalEmbeddings(Embeddings): class AngularTwoDimensionalEmbeddings(Embeddings):

View File

@ -9,6 +9,7 @@ from qdrant_client.http import models as rest
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant from langchain.vectorstores import Qdrant
from langchain.vectorstores.qdrant import QdrantException
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
) )
@ -537,3 +538,148 @@ def test_qdrant_similarity_search_with_relevance_scores(
assert all( assert all(
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
) )
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None:
    """Check that a second Qdrant.from_texts call appends to the same collection.

    Five texts and then two more are loaded into one on-disk collection with
    identical settings; the collection must hold all seven points afterwards.
    """
    from qdrant_client import QdrantClient

    collection_name = "test"
    embeddings = ConsistentFakeEmbeddings()
    text_batches = (
        ["lorem", "ipsum", "dolor", "sit", "amet"],
        ["foo", "bar"],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        storage_path = str(tmpdir)
        for batch in text_batches:
            store = Qdrant.from_texts(
                batch,
                embeddings,
                collection_name=collection_name,
                path=storage_path,
                vector_name=vector_name,
            )
            # Drop the wrapper before the same path is opened again.
            # NOTE(review): presumably this releases the local storage - confirm.
            del store

        client = QdrantClient(path=storage_path)
        point_count = client.count(collection_name).count
        assert 7 == point_count
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_raises_error_on_different_dimensionality(
    vector_name: Optional[str],
) -> None:
    """Check that reusing a collection with another vector size is rejected.

    The collection is created from 10-dimensional embeddings; a follow-up
    Qdrant.from_texts call with 5-dimensional embeddings and otherwise equal
    settings has to raise QdrantException.
    """
    collection_name = "test"
    with tempfile.TemporaryDirectory() as tmpdir:
        # Settings shared by both calls; only the embeddings differ.
        shared_kwargs = {
            "collection_name": collection_name,
            "path": str(tmpdir),
            "vector_name": vector_name,
        }
        initial_store = Qdrant.from_texts(
            ["lorem", "ipsum", "dolor", "sit", "amet"],
            ConsistentFakeEmbeddings(dimensionality=10),
            **shared_kwargs,
        )
        # Drop the first wrapper before the same path is opened again.
        del initial_store

        smaller_embeddings = ConsistentFakeEmbeddings(dimensionality=5)
        with pytest.raises(QdrantException):
            Qdrant.from_texts(["foo", "bar"], smaller_embeddings, **shared_kwargs)
@pytest.mark.parametrize(
    ["first_vector_name", "second_vector_name"],
    [
        (None, "custom-vector"),
        ("custom-vector", None),
        ("my-first-vector", "my-second_vector"),
    ],
)
def test_qdrant_from_texts_raises_error_on_different_vector_name(
    first_vector_name: Optional[str],
    second_vector_name: Optional[str],
) -> None:
    """Test if Qdrant.from_texts raises an exception if vector name does not match

    Both calls use embeddings of the same dimensionality, so the vector name is
    the only setting that differs between them. A dimensionality mismatch also
    raises QdrantException, which would let this test pass even if the vector
    name validation were broken.
    """
    collection_name = "test"
    dimensionality = 10
    with tempfile.TemporaryDirectory() as tmpdir:
        vec_store = Qdrant.from_texts(
            ["lorem", "ipsum", "dolor", "sit", "amet"],
            ConsistentFakeEmbeddings(dimensionality=dimensionality),
            collection_name=collection_name,
            path=str(tmpdir),
            vector_name=first_vector_name,
        )
        # Drop the first wrapper before the same path is opened again.
        del vec_store

        with pytest.raises(QdrantException):
            Qdrant.from_texts(
                ["foo", "bar"],
                ConsistentFakeEmbeddings(dimensionality=dimensionality),
                collection_name=collection_name,
                path=str(tmpdir),
                vector_name=second_vector_name,
            )
def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
    """Test if Qdrant.from_texts raises an exception if distance does not match

    Both calls use embeddings of the same dimensionality, so the distance
    function is the only setting that differs between them. With differing
    dimensionality the exception would come from the vector size validation,
    and the distance comparison would never be exercised.
    """
    collection_name = "test"
    dimensionality = 10
    with tempfile.TemporaryDirectory() as tmpdir:
        vec_store = Qdrant.from_texts(
            ["lorem", "ipsum", "dolor", "sit", "amet"],
            ConsistentFakeEmbeddings(dimensionality=dimensionality),
            collection_name=collection_name,
            path=str(tmpdir),
            distance_func="Cosine",
        )
        # Drop the first wrapper before the same path is opened again.
        del vec_store

        with pytest.raises(QdrantException):
            Qdrant.from_texts(
                ["foo", "bar"],
                ConsistentFakeEmbeddings(dimensionality=dimensionality),
                collection_name=collection_name,
                path=str(tmpdir),
                distance_func="Euclid",
            )
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
def test_qdrant_from_texts_recreates_collection_on_force_recreate(
    vector_name: Optional[str],
) -> None:
    """Check that force_recreate=True replaces a collection despite a mismatch.

    The collection is first built from five 10-dimensional vectors, then
    Qdrant.from_texts is called with two 5-dimensional vectors and
    force_recreate=True; only the two new points must be present afterwards.
    """
    from qdrant_client import QdrantClient

    collection_name = "test"
    with tempfile.TemporaryDirectory() as tmpdir:
        storage_path = str(tmpdir)

        original_store = Qdrant.from_texts(
            ["lorem", "ipsum", "dolor", "sit", "amet"],
            ConsistentFakeEmbeddings(dimensionality=10),
            collection_name=collection_name,
            path=storage_path,
            vector_name=vector_name,
        )
        # Drop each wrapper before the same path is opened again.
        del original_store

        recreated_store = Qdrant.from_texts(
            ["foo", "bar"],
            ConsistentFakeEmbeddings(dimensionality=5),
            collection_name=collection_name,
            path=storage_path,
            vector_name=vector_name,
            force_recreate=True,
        )
        del recreated_store

        client = QdrantClient(path=storage_path)
        point_count = client.count(collection_name).count
        assert 2 == point_count