mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
Support named vectors in Qdrant (#6871)
# Description This PR makes it possible to use named vectors from Qdrant in Langchain. That was requested multiple times, as people want to reuse externally created collections in Langchain. It doesn't change anything for the existing applications. The changes were covered with some integration tests and included in the docs. ## Example ```python Qdrant.from_documents( docs, embeddings, location=":memory:", collection_name="my_documents", vector_name="custom_vector", ) ``` ### Issue: #2594 Tagging @rlancemartin & @eyurtsev. I'd appreciate your review.
This commit is contained in:
parent
9ca1cf003c
commit
140ba682f1
@ -626,6 +626,44 @@
|
||||
"source": [
|
||||
"## Customizing Qdrant\n",
|
||||
"\n",
|
||||
"There are some options to use an existing Qdrant collection within your Langchain application. In such cases you may need to define how to map Qdrant point into the Langchain `Document`.\n",
|
||||
"\n",
|
||||
"### Named vectors\n",
|
||||
"\n",
|
||||
"Qdrant supports [multiple vectors per point](https://qdrant.tech/documentation/concepts/collections/#collection-with-multiple-vectors) by named vectors. Langchain requires just a single embedding per document and, by default, uses a single vector. However, if you work with a collection created externally or want to have the named vector used, you can configure it by providing its name.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Qdrant.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" location=\":memory:\",\n",
|
||||
" collection_name=\"my_documents_2\",\n",
|
||||
" vector_name=\"custom_vector\",\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"As a Langchain user, you won't see any difference whether you use named vectors or not. Qdrant integration will handle the conversion under the hood."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Metadata\n",
|
||||
"\n",
|
||||
"Qdrant stores your vector embeddings along with the optional JSON-like payload. Payloads are optional, but since LangChain assumes the embeddings are generated from the documents, we keep the context data, so you can extract the original texts as well.\n",
|
||||
"\n",
|
||||
"By default, your document is going to be stored in the following payload structure:\n",
|
||||
@ -639,8 +677,11 @@
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"You can, however, decide to use different keys for the page content and metadata. That's useful if you already have a collection that you'd like to reuse. You can always change the "
|
||||
]
|
||||
"You can, however, decide to use different keys for the page content and metadata. That's useful if you already have a collection that you'd like to reuse."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
@ -52,6 +52,7 @@ class Qdrant(VectorStore):
|
||||
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
VECTOR_NAME = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -60,6 +61,7 @@ class Qdrant(VectorStore):
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
embedding_function: Optional[Callable] = None, # deprecated
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
@ -94,6 +96,7 @@ class Qdrant(VectorStore):
|
||||
self.collection_name = collection_name
|
||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||
self.vector_name = vector_name or self.VECTOR_NAME
|
||||
|
||||
if embedding_function is not None:
|
||||
warnings.warn(
|
||||
@ -143,20 +146,26 @@ class Qdrant(VectorStore):
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
batch_ids = list(islice(ids_iterator, batch_size))
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=rest.Batch.construct(
|
||||
ids=batch_ids,
|
||||
vectors=self._embed_texts(batch_texts),
|
||||
payloads=self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
),
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = self._embed_texts(batch_texts)
|
||||
if self.vector_name is not None:
|
||||
batch_embeddings = { # type: ignore[assignment]
|
||||
self.vector_name: batch_embeddings
|
||||
}
|
||||
|
||||
points = rest.Batch.construct(
|
||||
ids=batch_ids,
|
||||
vectors=batch_embeddings,
|
||||
payloads=self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
),
|
||||
)
|
||||
|
||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
@ -315,7 +324,6 @@ class Qdrant(VectorStore):
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
@ -373,7 +381,6 @@ class Qdrant(VectorStore):
|
||||
distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
|
||||
if filter is not None and isinstance(filter, dict):
|
||||
warnings.warn(
|
||||
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
||||
@ -384,9 +391,14 @@ class Qdrant(VectorStore):
|
||||
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
||||
else:
|
||||
qdrant_filter = filter
|
||||
|
||||
query_vector = embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=embedding,
|
||||
query_vector=query_vector,
|
||||
query_filter=qdrant_filter,
|
||||
search_params=search_params,
|
||||
limit=k,
|
||||
@ -454,18 +466,26 @@ class Qdrant(VectorStore):
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self._embed_query(query)
|
||||
query_vector = query_embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
|
||||
|
||||
embedding = self._embed_query(query)
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=embedding,
|
||||
query_vector=query_vector,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
limit=fetch_k,
|
||||
)
|
||||
embeddings = [result.vector for result in results]
|
||||
embeddings = [
|
||||
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||
if self.vector_name is not None
|
||||
else result.vector
|
||||
for result in results
|
||||
]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
np.array(query_embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
return [
|
||||
self._document_from_scored_point(
|
||||
@ -496,6 +516,7 @@ class Qdrant(VectorStore):
|
||||
distance_func: str = "Cosine",
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
batch_size: int = 64,
|
||||
shard_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
@ -558,6 +579,9 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key:
|
||||
A payload key used to store the metadata of the document.
|
||||
Default: "metadata"
|
||||
vector_name:
|
||||
Name of the vector to be used internally in Qdrant.
|
||||
Default: None
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 64
|
||||
@ -638,12 +662,21 @@ class Qdrant(VectorStore):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
)
|
||||
|
||||
# If vector name was provided, we're going to use the named vectors feature
|
||||
# with just a single vector.
|
||||
if vector_name is not None:
|
||||
vectors_config = { # type: ignore[assignment]
|
||||
vector_name: vectors_config,
|
||||
}
|
||||
|
||||
client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
),
|
||||
vectors_config=vectors_config,
|
||||
shard_number=shard_number,
|
||||
replication_factor=replication_factor,
|
||||
write_consistency_factor=write_consistency_factor,
|
||||
@ -666,27 +699,31 @@ class Qdrant(VectorStore):
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = embedding.embed_documents(batch_texts)
|
||||
if vector_name is not None:
|
||||
batch_embeddings = { # type: ignore[assignment]
|
||||
vector_name: batch_embeddings
|
||||
}
|
||||
|
||||
client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=rest.Batch.construct(
|
||||
ids=batch_ids,
|
||||
vectors=batch_embeddings,
|
||||
payloads=cls._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
),
|
||||
points = rest.Batch.construct(
|
||||
ids=batch_ids,
|
||||
vectors=batch_embeddings,
|
||||
payloads=cls._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
content_payload_key,
|
||||
metadata_payload_key,
|
||||
),
|
||||
)
|
||||
|
||||
client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
return cls(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
embeddings=embedding,
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -14,15 +14,8 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
(Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
|
||||
("foo", "bar"),
|
||||
(Qdrant.CONTENT_KEY, "bar"),
|
||||
("foo", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
def test_qdrant_similarity_search(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
@ -41,15 +34,8 @@ def test_qdrant_similarity_search(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
(Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
|
||||
("foo", "bar"),
|
||||
(Qdrant.CONTENT_KEY, "bar"),
|
||||
("foo", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
def test_qdrant_similarity_search_by_vector(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
@ -69,15 +55,8 @@ def test_qdrant_similarity_search_by_vector(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
(Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
|
||||
("foo", "bar"),
|
||||
(Qdrant.CONTENT_KEY, "bar"),
|
||||
("foo", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
def test_qdrant_similarity_search_with_score_by_vector(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
@ -100,11 +79,16 @@ def test_qdrant_similarity_search_with_score_by_vector(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_add_documents(batch_size: int) -> None:
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
def test_qdrant_add_documents(batch_size: int, vector_name: Optional[str]) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch: Qdrant = Qdrant.from_texts(
|
||||
texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
location=":memory:",
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
|
||||
new_texts = ["foobar", "foobaz"]
|
||||
@ -122,6 +106,7 @@ def test_qdrant_add_documents(batch_size: int) -> None:
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
|
||||
"""Test end to end Qdrant.add_texts returns unique ids."""
|
||||
docsearch: Qdrant = Qdrant.from_texts(
|
||||
["foobar"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
@ -135,15 +120,8 @@ def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
(Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
|
||||
("test_content", "test_payload"),
|
||||
(Qdrant.CONTENT_KEY, "payload_test"),
|
||||
("content_test", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
def test_qdrant_with_metadatas(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
@ -190,7 +168,10 @@ def test_qdrant_similarity_search_filters(batch_size: int) -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_qdrant_similarity_search_with_relevance_score_no_threshold() -> None:
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
def test_qdrant_similarity_search_with_relevance_score_no_threshold(
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [
|
||||
@ -202,6 +183,7 @@ def test_qdrant_similarity_search_with_relevance_score_no_threshold() -> None:
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
location=":memory:",
|
||||
vector_name=vector_name,
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores(
|
||||
"foo", k=3, score_threshold=None
|
||||
@ -302,17 +284,14 @@ def test_qdrant_similarity_search_filters_with_qdrant_filters() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize(
|
||||
["content_payload_key", "metadata_payload_key"],
|
||||
[
|
||||
(Qdrant.CONTENT_KEY, Qdrant.METADATA_KEY),
|
||||
("test_content", "test_payload"),
|
||||
(Qdrant.CONTENT_KEY, "payload_test"),
|
||||
("content_test", Qdrant.METADATA_KEY),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "test_content"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "test_metadata"])
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
def test_qdrant_max_marginal_relevance_search(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
batch_size: int,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
vector_name: Optional[str],
|
||||
) -> None:
|
||||
"""Test end to end construction and MRR search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
@ -325,6 +304,7 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
|
||||
assert output == [
|
||||
@ -344,6 +324,7 @@ def test_qdrant_max_marginal_relevance_search(
|
||||
def test_qdrant_embedding_interface(
|
||||
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
|
||||
) -> None:
|
||||
"""Test Qdrant may accept different types for embeddings."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
@ -364,9 +345,10 @@ def test_qdrant_embedding_interface(
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_qdrant_embedding_interface_raises(
|
||||
def test_qdrant_embedding_interface_raises_value_error(
|
||||
embeddings: Optional[Embeddings], embedding_function: Optional[Callable]
|
||||
) -> None:
|
||||
"""Test Qdrant requires only one method for embeddings."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
@ -381,21 +363,24 @@ def test_qdrant_embedding_interface_raises(
|
||||
)
|
||||
|
||||
|
||||
def test_qdrant_stores_duplicated_texts() -> None:
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
def test_qdrant_add_texts_stores_duplicated_texts(vector_name: Optional[str]) -> None:
|
||||
"""Test end to end Qdrant.add_texts stores duplicated texts separately."""
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
collection_name = "test"
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
vectors_config=rest.VectorParams(size=10, distance=rest.Distance.COSINE),
|
||||
)
|
||||
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
|
||||
if vector_name is not None:
|
||||
vectors_config = {vector_name: vectors_config} # type: ignore[assignment]
|
||||
client.recreate_collection(collection_name, vectors_config=vectors_config)
|
||||
|
||||
vec_store = Qdrant(
|
||||
client,
|
||||
collection_name,
|
||||
embeddings=ConsistentFakeEmbeddings(),
|
||||
vector_name=vector_name,
|
||||
)
|
||||
ids = vec_store.add_texts(["abc", "abc"], [{"a": 1}, {"a": 2}])
|
||||
|
||||
@ -404,6 +389,7 @@ def test_qdrant_stores_duplicated_texts() -> None:
|
||||
|
||||
|
||||
def test_qdrant_from_texts_stores_duplicated_texts() -> None:
|
||||
"""Test end to end Qdrant.from_texts stores duplicated texts separately."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@ -420,7 +406,11 @@ def test_qdrant_from_texts_stores_duplicated_texts() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_from_texts_stores_ids(batch_size: int) -> None:
|
||||
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
|
||||
def test_qdrant_from_texts_stores_ids(
|
||||
batch_size: int, vector_name: Optional[str]
|
||||
) -> None:
|
||||
"""Test end to end Qdrant.from_texts stores provided ids."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@ -435,6 +425,7 @@ def test_qdrant_from_texts_stores_ids(batch_size: int) -> None:
|
||||
collection_name="test",
|
||||
path=str(tmpdir),
|
||||
batch_size=batch_size,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
del vec_store
|
||||
|
||||
@ -446,6 +437,7 @@ def test_qdrant_from_texts_stores_ids(batch_size: int) -> None:
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
|
||||
"""Test end to end Qdrant.add_texts stores provided ids."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
ids = [
|
||||
@ -467,3 +459,57 @@ def test_qdrant_add_texts_stores_ids(batch_size: int) -> None:
|
||||
assert 2 == client.count("test").count
|
||||
stored_ids = [point.id for point in client.scroll("test")[0]]
|
||||
assert set(ids) == set(stored_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vector_name", ["custom-vector"])
|
||||
def test_qdrant_from_texts_stores_embeddings_as_named_vectors(vector_name: str) -> None:
|
||||
"""Test end to end Qdrant.from_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
vec_store = Qdrant.from_texts(
|
||||
["lorem", "ipsum", "dolor", "sit", "amet"],
|
||||
ConsistentFakeEmbeddings(),
|
||||
collection_name=collection_name,
|
||||
path=str(tmpdir),
|
||||
vector_name=vector_name,
|
||||
)
|
||||
del vec_store
|
||||
|
||||
client = QdrantClient(path=str(tmpdir))
|
||||
assert 5 == client.count("test").count
|
||||
assert all(
|
||||
vector_name in point.vector # type: ignore[operator]
|
||||
for point in client.scroll(collection_name, with_vectors=True)[0]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vector_name", ["custom-vector"])
|
||||
def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -> None:
|
||||
"""Test end to end Qdrant.add_texts stores named vectors if name is provided."""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
collection_name = "test"
|
||||
|
||||
client = QdrantClient(":memory:")
|
||||
client.recreate_collection(
|
||||
collection_name,
|
||||
vectors_config={
|
||||
vector_name: rest.VectorParams(size=10, distance=rest.Distance.COSINE)
|
||||
},
|
||||
)
|
||||
|
||||
vec_store = Qdrant(
|
||||
client,
|
||||
collection_name,
|
||||
ConsistentFakeEmbeddings(),
|
||||
vector_name=vector_name,
|
||||
)
|
||||
vec_store.add_texts(["lorem", "ipsum", "dolor", "sit", "amet"])
|
||||
|
||||
assert 5 == client.count("test").count
|
||||
assert all(
|
||||
vector_name in point.vector # type: ignore[operator]
|
||||
for point in client.scroll(collection_name, with_vectors=True)[0]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user