mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
Reuse the existing collection if configured properly in Qdrant.from_texts (#7530)
This PR changes the behavior of `Qdrant.from_texts` so the collection is reused if not requested to recreate it. Previously, calling `Qdrant.from_texts` or `Qdrant.from_documents` resulted in removing the old data which was confusing for many.
This commit is contained in:
parent
6674b33cf5
commit
1f83b5f47e
File diff suppressed because it is too large
Load Diff
@ -34,6 +34,10 @@ if TYPE_CHECKING:
|
|||||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantException(Exception):
|
||||||
|
"""Base class for all the Qdrant related exceptions"""
|
||||||
|
|
||||||
|
|
||||||
class Qdrant(VectorStore):
|
class Qdrant(VectorStore):
|
||||||
"""Wrapper around Qdrant vector database.
|
"""Wrapper around Qdrant vector database.
|
||||||
|
|
||||||
@ -552,6 +556,7 @@ class Qdrant(VectorStore):
|
|||||||
wal_config: Optional[common_types.WalConfigDiff] = None,
|
wal_config: Optional[common_types.WalConfigDiff] = None,
|
||||||
quantization_config: Optional[common_types.QuantizationConfig] = None,
|
quantization_config: Optional[common_types.QuantizationConfig] = None,
|
||||||
init_from: Optional[common_types.InitFrom] = None,
|
init_from: Optional[common_types.InitFrom] = None,
|
||||||
|
force_recreate: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Qdrant:
|
) -> Qdrant:
|
||||||
"""Construct Qdrant wrapper from a list of texts.
|
"""Construct Qdrant wrapper from a list of texts.
|
||||||
@ -636,6 +641,8 @@ class Qdrant(VectorStore):
|
|||||||
Params for quantization, if None - quantization will be disabled
|
Params for quantization, if None - quantization will be disabled
|
||||||
init_from:
|
init_from:
|
||||||
Use data stored in another collection to initialize this collection
|
Use data stored in another collection to initialize this collection
|
||||||
|
force_recreate:
|
||||||
|
Force recreating the collection
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Additional arguments passed directly into REST client initialization
|
Additional arguments passed directly into REST client initialization
|
||||||
|
|
||||||
@ -663,7 +670,9 @@ class Qdrant(VectorStore):
|
|||||||
"Please install it with `pip install qdrant-client`."
|
"Please install it with `pip install qdrant-client`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from grpc import RpcError
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
|
|
||||||
# Just do a single quick embedding to get vector size
|
# Just do a single quick embedding to get vector size
|
||||||
partial_embeddings = embedding.embed_documents(texts[:1])
|
partial_embeddings = embedding.embed_documents(texts[:1])
|
||||||
@ -687,62 +696,98 @@ class Qdrant(VectorStore):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
vectors_config = rest.VectorParams(
|
try:
|
||||||
size=vector_size,
|
# Skip any validation in case of forced collection recreate.
|
||||||
distance=rest.Distance[distance_func],
|
if force_recreate:
|
||||||
)
|
raise ValueError
|
||||||
|
|
||||||
# If vector name was provided, we're going to use the named vectors feature
|
# Get the vector configuration of the existing collection and vector, if it
|
||||||
# with just a single vector.
|
# was specified. If the old configuration does not match the current one,
|
||||||
if vector_name is not None:
|
# an exception is being thrown.
|
||||||
vectors_config = { # type: ignore[assignment]
|
collection_info = client.get_collection(collection_name=collection_name)
|
||||||
vector_name: vectors_config,
|
current_vector_config = collection_info.config.params.vectors
|
||||||
}
|
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||||
|
if vector_name not in current_vector_config:
|
||||||
|
raise QdrantException(
|
||||||
|
f"Existing Qdrant collection {collection_name} does not "
|
||||||
|
f"contain vector named {vector_name}. Did you mean one of the "
|
||||||
|
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||||
|
f"If you want to recreate the collection, set `force_recreate` "
|
||||||
|
f"parameter to `True`."
|
||||||
|
)
|
||||||
|
current_vector_config = current_vector_config.get(
|
||||||
|
vector_name
|
||||||
|
) # type: ignore[assignment]
|
||||||
|
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||||
|
raise QdrantException(
|
||||||
|
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||||
|
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||||
|
f"existing named vectors: "
|
||||||
|
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||||
|
f"If you want to recreate the collection, set `force_recreate` "
|
||||||
|
f"parameter to `True`."
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||||
|
):
|
||||||
|
raise QdrantException(
|
||||||
|
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||||
|
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||||
|
f"`None`. If you want to recreate the collection, set "
|
||||||
|
f"`force_recreate` parameter to `True`."
|
||||||
|
)
|
||||||
|
|
||||||
client.recreate_collection(
|
# Check if the vector configuration has the same dimensionality.
|
||||||
collection_name=collection_name,
|
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||||
vectors_config=vectors_config,
|
raise QdrantException(
|
||||||
shard_number=shard_number,
|
f"Existing Qdrant collection is configured for vectors with "
|
||||||
replication_factor=replication_factor,
|
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||||
write_consistency_factor=write_consistency_factor,
|
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||||
on_disk_payload=on_disk_payload,
|
f"If you want to recreate the collection, set `force_recreate` "
|
||||||
hnsw_config=hnsw_config,
|
f"parameter to `True`."
|
||||||
optimizers_config=optimizers_config,
|
)
|
||||||
wal_config=wal_config,
|
|
||||||
quantization_config=quantization_config,
|
|
||||||
init_from=init_from,
|
|
||||||
timeout=timeout, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
texts_iterator = iter(texts)
|
current_distance_func = (
|
||||||
metadatas_iterator = iter(metadatas or [])
|
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||||
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
)
|
||||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
if current_distance_func != distance_func:
|
||||||
# Take the corresponding metadata and id for each text in a batch
|
raise QdrantException(
|
||||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
f"Existing Qdrant collection is configured for "
|
||||||
batch_ids = list(islice(ids_iterator, batch_size))
|
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||||
|
f"similarity. Please set `distance_func` parameter to "
|
||||||
# Generate the embeddings for all the texts in a batch
|
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||||
batch_embeddings = embedding.embed_documents(batch_texts)
|
f"recreate the collection, set `force_recreate` parameter to "
|
||||||
if vector_name is not None:
|
f"`True`."
|
||||||
batch_embeddings = { # type: ignore[assignment]
|
)
|
||||||
vector_name: batch_embeddings
|
except (UnexpectedResponse, RpcError, ValueError):
|
||||||
}
|
vectors_config = rest.VectorParams(
|
||||||
|
size=vector_size,
|
||||||
points = rest.Batch.construct(
|
distance=rest.Distance[distance_func],
|
||||||
ids=batch_ids,
|
|
||||||
vectors=batch_embeddings,
|
|
||||||
payloads=cls._build_payloads(
|
|
||||||
batch_texts,
|
|
||||||
batch_metadatas,
|
|
||||||
content_payload_key,
|
|
||||||
metadata_payload_key,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client.upsert(collection_name=collection_name, points=points)
|
# If vector name was provided, we're going to use the named vectors feature
|
||||||
|
# with just a single vector.
|
||||||
|
if vector_name is not None:
|
||||||
|
vectors_config = { # type: ignore[assignment]
|
||||||
|
vector_name: vectors_config,
|
||||||
|
}
|
||||||
|
|
||||||
return cls(
|
client.recreate_collection(
|
||||||
|
collection_name=collection_name,
|
||||||
|
vectors_config=vectors_config,
|
||||||
|
shard_number=shard_number,
|
||||||
|
replication_factor=replication_factor,
|
||||||
|
write_consistency_factor=write_consistency_factor,
|
||||||
|
on_disk_payload=on_disk_payload,
|
||||||
|
hnsw_config=hnsw_config,
|
||||||
|
optimizers_config=optimizers_config,
|
||||||
|
wal_config=wal_config,
|
||||||
|
quantization_config=quantization_config,
|
||||||
|
init_from=init_from,
|
||||||
|
timeout=timeout, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
qdrant = cls(
|
||||||
client=client,
|
client=client,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embeddings=embedding,
|
embeddings=embedding,
|
||||||
@ -751,6 +796,8 @@ class Qdrant(VectorStore):
|
|||||||
distance_strategy=distance_func,
|
distance_strategy=distance_func,
|
||||||
vector_name=vector_name,
|
vector_name=vector_name,
|
||||||
)
|
)
|
||||||
|
qdrant.add_texts(texts, metadatas, ids, batch_size)
|
||||||
|
return qdrant
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_payloads(
|
def _build_payloads(
|
||||||
|
@ -27,8 +27,9 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
|||||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||||
vectors for the same texts."""
|
vectors for the same texts."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, dimensionality: int = 10) -> None:
|
||||||
self.known_texts: List[str] = []
|
self.known_texts: List[str] = []
|
||||||
|
self.dimensionality = dimensionality
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Return consistent embeddings for each text seen so far."""
|
"""Return consistent embeddings for each text seen so far."""
|
||||||
@ -36,7 +37,9 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
|||||||
for text in texts:
|
for text in texts:
|
||||||
if text not in self.known_texts:
|
if text not in self.known_texts:
|
||||||
self.known_texts.append(text)
|
self.known_texts.append(text)
|
||||||
vector = [float(1.0)] * 9 + [float(self.known_texts.index(text))]
|
vector = [float(1.0)] * (self.dimensionality - 1) + [
|
||||||
|
float(self.known_texts.index(text))
|
||||||
|
]
|
||||||
out_vectors.append(vector)
|
out_vectors.append(vector)
|
||||||
return out_vectors
|
return out_vectors
|
||||||
|
|
||||||
@ -44,8 +47,10 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
|||||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||||
one if the text is unknown."""
|
one if the text is unknown."""
|
||||||
if text not in self.known_texts:
|
if text not in self.known_texts:
|
||||||
return [float(1.0)] * 9 + [float(0.0)]
|
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
|
||||||
return [float(1.0)] * 9 + [float(self.known_texts.index(text))]
|
return [float(1.0)] * (self.dimensionality - 1) + [
|
||||||
|
float(self.known_texts.index(text))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AngularTwoDimensionalEmbeddings(Embeddings):
|
class AngularTwoDimensionalEmbeddings(Embeddings):
|
||||||
|
@ -9,6 +9,7 @@ from qdrant_client.http import models as rest
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores import Qdrant
|
from langchain.vectorstores import Qdrant
|
||||||
|
from langchain.vectorstores.qdrant import QdrantException
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
ConsistentFakeEmbeddings,
|
ConsistentFakeEmbeddings,
|
||||||
)
|
)
|
||||||
@ -537,3 +538,148 @@ def test_qdrant_similarity_search_with_relevance_scores(
|
|||||||
assert all(
|
assert all(
|
||||||
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
|
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||||
|
def test_qdrant_from_texts_reuses_same_collection(vector_name: Optional[str]) -> None:
|
||||||
|
"""Test if Qdrant.from_texts reuses the same collection"""
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
collection_name = "test"
|
||||||
|
embeddings = ConsistentFakeEmbeddings()
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||||
|
embeddings,
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["foo", "bar"],
|
||||||
|
embeddings,
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
client = QdrantClient(path=str(tmpdir))
|
||||||
|
assert 7 == client.count(collection_name).count
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||||
|
def test_qdrant_from_texts_raises_error_on_different_dimensionality(
|
||||||
|
vector_name: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Test if Qdrant.from_texts raises an exception if dimensionality does not match"""
|
||||||
|
collection_name = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=10),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
with pytest.raises(QdrantException):
|
||||||
|
Qdrant.from_texts(
|
||||||
|
["foo", "bar"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=5),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["first_vector_name", "second_vector_name"],
|
||||||
|
[
|
||||||
|
(None, "custom-vector"),
|
||||||
|
("custom-vector", None),
|
||||||
|
("my-first-vector", "my-second_vector"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_qdrant_from_texts_raises_error_on_different_vector_name(
|
||||||
|
first_vector_name: Optional[str],
|
||||||
|
second_vector_name: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Test if Qdrant.from_texts raises an exception if vector name does not match"""
|
||||||
|
collection_name = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=10),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=first_vector_name,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
with pytest.raises(QdrantException):
|
||||||
|
Qdrant.from_texts(
|
||||||
|
["foo", "bar"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=5),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=second_vector_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qdrant_from_texts_raises_error_on_different_distance() -> None:
|
||||||
|
"""Test if Qdrant.from_texts raises an exception if distance does not match"""
|
||||||
|
collection_name = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=10),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
distance_func="Cosine",
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
with pytest.raises(QdrantException):
|
||||||
|
Qdrant.from_texts(
|
||||||
|
["foo", "bar"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=5),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
distance_func="Euclid",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("vector_name", [None, "custom-vector"])
|
||||||
|
def test_qdrant_from_texts_recreates_collection_on_force_recreate(
|
||||||
|
vector_name: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Test if Qdrant.from_texts recreates the collection even if config mismatches"""
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
collection_name = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=10),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
vec_store = Qdrant.from_texts(
|
||||||
|
["foo", "bar"],
|
||||||
|
ConsistentFakeEmbeddings(dimensionality=5),
|
||||||
|
collection_name=collection_name,
|
||||||
|
path=str(tmpdir),
|
||||||
|
vector_name=vector_name,
|
||||||
|
force_recreate=True,
|
||||||
|
)
|
||||||
|
del vec_store
|
||||||
|
|
||||||
|
client = QdrantClient(path=str(tmpdir))
|
||||||
|
assert 2 == client.count(collection_name).count
|
||||||
|
Loading…
Reference in New Issue
Block a user