mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-16 23:13:31 +00:00)
Add "Astra DB" vector store integration (#12966)
# Astra DB Vector store integration

- **Description:** This PR adds a `VectorStore` implementation for DataStax Astra DB using its HTTP API.
- **Issue:** no related issue.
- **Dependencies:** a new dependency, `astrapy` (`>=0.5.3`), required for this integration and added to pyproject.toml as an optional dependency, as per guidelines.
- **Tag maintainer:** I recently mentioned to @baskaryan that this integration was coming.
- **Twitter handle:** `@rsprrs` if you want to mention me.

This PR introduces the `AstraDB` vector store class, extensive integration test coverage, a reworking of the documentation that so far conflated Cassandra and Astra DB on a single "provider" page, and a new, completely reworked vector-store example notebook (shared with the Cassandra store, since parts of the flow are common to the two APIs). I also took care to ensure the docs (and the redirects therein) behave correctly.

All style, linting, typecheck and test targets pass as far as the `AstraDB` integration is concerned. I could build the documentation and check it all right, but ran into trouble with the `api_docs_build` makefile target, which I could not verify: `Error: Unable to import module 'plan_and_execute.agent_executor' with error: No module named 'langchain_experimental'` was the first of many similar errors.

Thank you for a review!
Stefano

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
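For reviewers who want to try the new store quickly, here is a minimal usage sketch (it mirrors the example in the class docstring further down; the collection name, token and API endpoint are placeholders to be replaced with your own values, and `OpenAIEmbeddings` is just one possible embedding choice):

    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores import AstraDB

    embeddings = OpenAIEmbeddings()
    vectorstore = AstraDB(
        embedding=embeddings,
        collection_name="my_store",  # placeholder collection name
        token="AstraCS:...",  # your Astra DB application token
        api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
    )
    vectorstore.add_texts(["Giraffes", "All good here"])
    results = vectorstore.similarity_search("Everything's ok", k=1)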
@@ -104,6 +104,12 @@ def _import_cassandra() -> Any:
    return Cassandra


def _import_astradb() -> Any:
    from langchain.vectorstores.astradb import AstraDB

    return AstraDB


def _import_chroma() -> Any:
    from langchain.vectorstores.chroma import Chroma

@@ -443,6 +449,8 @@ def __getattr__(name: str) -> Any:
        return _import_baiducloud_vector_search()
    elif name == "Cassandra":
        return _import_cassandra()
    elif name == "AstraDB":
        return _import_astradb()
    elif name == "Chroma":
        return _import_chroma()
    elif name == "Clarifai":
@@ -561,6 +569,7 @@ __all__ = [
    "AzureSearch",
    "Bagel",
    "Cassandra",
    "AstraDB",
    "Chroma",
    "Clarifai",
    "Clickhouse",
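For context, the hunks above wire `AstraDB` into the module's lazy-import machinery: importing `langchain.vectorstores` stays cheap, and the heavy import only happens when the `AstraDB` attribute is first accessed, at which point `__getattr__` dispatches to `_import_astradb()`. A minimal sketch of that pattern, simplified to this one name (the real module dispatches many names the same way):

    # module-level lazy import (PEP 562 __getattr__), illustrative sketch only
    from typing import Any

    def _import_astradb() -> Any:
        from langchain.vectorstores.astradb import AstraDB

        return AstraDB

    def __getattr__(name: str) -> Any:
        # invoked only when 'name' is not already defined in the module
        if name == "AstraDB":
            return _import_astradb()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")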
libs/langchain/langchain/vectorstores/astradb.py (new file, 751 lines)
@@ -0,0 +1,751 @@
from __future__ import annotations

import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
)

import numpy as np

from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.utils.iter import batch_iterate
from langchain.vectorstores.utils import maximal_marginal_relevance

ADBVST = TypeVar("ADBVST", bound="AstraDB")
T = TypeVar("T")
U = TypeVar("U")
DocDict = Dict[str, Any]  # dicts expressing entries to insert

# Batch/concurrency default values (if parameters not provided):
# Size of batches for bulk insertions:
#   (20 is the max batch size for the HTTP API at the time of writing)
DEFAULT_BATCH_SIZE = 20
# Number of threads to insert batches concurrently:
DEFAULT_BULK_INSERT_BATCH_CONCURRENCY = 5
# Number of threads in a batch to insert pre-existing entries:
DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY = 10
# Number of threads (for deleting multiple rows concurrently):
DEFAULT_BULK_DELETE_CONCURRENCY = 20


def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
    visited_keys: Set[U] = set()
    new_lst = []
    for item in lst:
        item_key = key(item)
        if item_key not in visited_keys:
            visited_keys.add(item_key)
            new_lst.append(item)
    return new_lst


class AstraDB(VectorStore):
    """Wrapper around DataStax Astra DB for vector-store workloads.

    To use it, you need a recent installation of the `astrapy` library
    and an Astra DB cloud database.

    For quickstart and details, visit:
        docs.datastax.com/en/astra/home/astra.html

    Example:
        .. code-block:: python

            from langchain.vectorstores import AstraDB
            from langchain.embeddings.openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            vectorstore = AstraDB(
                embedding=embeddings,
                collection_name="my_store",
                token="AstraCS:...",
                api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com"
            )

            vectorstore.add_texts(["Giraffes", "All good here"])
            results = vectorstore.similarity_search("Everything's ok", k=1)
    """

    @staticmethod
    def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
        if filter_dict is None:
            return {}
        else:
            return {f"metadata.{mdk}": mdv for mdk, mdv in filter_dict.items()}

    def __init__(
        self,
        *,
        embedding: Embeddings,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
        metric: Optional[str] = None,
        batch_size: Optional[int] = None,
        bulk_insert_batch_concurrency: Optional[int] = None,
        bulk_insert_overwrite_concurrency: Optional[int] = None,
        bulk_delete_concurrency: Optional[int] = None,
    ) -> None:
        try:
            from astrapy.db import (
                AstraDB as LibAstraDB,
            )
            from astrapy.db import (
                AstraDBCollection as LibAstraDBCollection,
            )
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )
        """
        Create an AstraDB vector store object.

        Args (only keyword-arguments accepted):
            embedding (Embeddings): embedding function to use.
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            metric (Optional[str]): similarity function to use out of those
                available in Astra DB. If left out, it will use Astra DB API's
                defaults (i.e. "cosine" - but, for performance reasons,
                "dot_product" is suggested if embeddings are normalized to one).

        Advanced arguments (coming with sensible defaults):
            batch_size (Optional[int]): Size of batches for bulk insertions.
            bulk_insert_batch_concurrency (Optional[int]): Number of threads
                to insert batches concurrently.
            bulk_insert_overwrite_concurrency (Optional[int]): Number of
                threads in a batch to insert pre-existing entries.
            bulk_delete_concurrency (Optional[int]): Number of threads
                (for deleting multiple rows concurrently).
        """

        # Conflicting-arg checks:
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' to AstraDB if passing "
                    "'token' and 'api_endpoint'."
                )

        self.embedding = embedding
        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace
        # Concurrency settings
        self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
        self.bulk_insert_batch_concurrency: int = (
            bulk_insert_batch_concurrency or DEFAULT_BULK_INSERT_BATCH_CONCURRENCY
        )
        self.bulk_insert_overwrite_concurrency: int = (
            bulk_insert_overwrite_concurrency
            or DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY
        )
        self.bulk_delete_concurrency: int = (
            bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
        )
        # "vector-related" settings
        self._embedding_dimension: Optional[int] = None
        self.metric = metric

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        self._provision_collection()

        self.collection = LibAstraDBCollection(
            collection_name=self.collection_name,
            astra_db=self.astra_db,
        )

    def _get_embedding_dimension(self) -> int:
        if self._embedding_dimension is None:
            self._embedding_dimension = len(
                self.embedding.embed_query("This is a sample sentence.")
            )
        return self._embedding_dimension

    def _drop_collection(self) -> None:
        """
        Drop the collection from storage.

        This is meant as an internal-usage method, no members
        are set other than actual deletion on the backend.
        """
        _ = self.astra_db.delete_collection(
            collection_name=self.collection_name,
        )
        return None

    def _provision_collection(self) -> None:
        """
        Run the API invocation to create the collection on the backend.

        Internal-usage method, no object members are set,
        other than working on the underlying actual storage.
        """
        _ = self.astra_db.create_collection(
            dimension=self._get_embedding_dimension(),
            collection_name=self.collection_name,
            metric=self.metric,
        )
        return None

    @property
    def embeddings(self) -> Embeddings:
        return self.embedding

    @staticmethod
    def _dont_flip_the_cos_score(similarity0to1: float) -> float:
        """Keep similarity from client unchanged as it's in [0:1] already."""
        return similarity0to1

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The underlying API calls already return a "score proper",
        i.e. one in [0, 1] where higher means more *similar*,
        so here the final score transformation is not reversing the interval:
        """
        return self._dont_flip_the_cos_score

    def clear(self) -> None:
        """Empty the collection of all its stored entries."""
        self._drop_collection()
        self._provision_collection()
        return None

    def delete_by_document_id(self, document_id: str) -> bool:
        """
        Remove a single document from the store, given its document_id (str).
        Return True if a document has indeed been deleted, False if ID not found.
        """
        deletion_response = self.collection.delete(document_id)
        return ((deletion_response or {}).get("status") or {}).get(
            "deletedCount", 0
        ) == 1

    def delete(
        self,
        ids: Optional[List[str]] = None,
        concurrency: Optional[int] = None,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Delete by vector ids.

        Args:
            ids (Optional[List[str]]): List of ids to delete.
            concurrency (Optional[int]): max number of threads issuing
                single-doc delete requests. Defaults to instance-level setting.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        if kwargs:
            warnings.warn(
                "Method 'delete' of AstraDB vector store invoked with "
                f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
                "which will be ignored."
            )

        if ids is None:
            raise ValueError("No ids provided to delete.")

        _max_workers = concurrency or self.bulk_delete_concurrency
        with ThreadPoolExecutor(max_workers=_max_workers) as tpe:
            _ = list(
                tpe.map(
                    self.delete_by_document_id,
                    ids,
                )
            )
        return True

    def delete_collection(self) -> None:
        """
        Completely delete the collection from the database (as opposed
        to 'clear()', which empties it only).
        Stored data is lost and unrecoverable, resources are freed.
        Use with caution.
        """
        self._drop_collection()
        return None

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        *,
        batch_size: Optional[int] = None,
        batch_concurrency: Optional[int] = None,
        overwrite_concurrency: Optional[int] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run texts through the embeddings and add them to the vectorstore.

        If passing explicit ids, those entries whose id is in the store already
        will be replaced.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of ids.
            batch_size (Optional[int]): Number of documents in each API call.
                Check the underlying Astra DB HTTP API specs for the max value
                (20 at the time of writing this). If not provided, defaults
                to the instance-level setting.
            batch_concurrency (Optional[int]): number of threads to process
                insertion batches concurrently. Defaults to instance-level
                setting if not provided.
            overwrite_concurrency (Optional[int]): number of threads to process
                pre-existing documents in each batch (which require individual
                API calls). Defaults to instance-level setting if not provided.

        Returns:
            List[str]: List of ids of the added texts.
        """

        if kwargs:
            warnings.warn(
                "Method 'add_texts' of AstraDB vector store invoked with "
                f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
                "which will be ignored."
            )

        _texts = list(texts)
        if ids is None:
            ids = [uuid.uuid4().hex for _ in _texts]
        if metadatas is None:
            metadatas = [{} for _ in _texts]
        #
        embedding_vectors = self.embedding.embed_documents(_texts)

        documents_to_insert = [
            {
                "content": b_txt,
                "_id": b_id,
                "$vector": b_emb,
                "metadata": b_md,
            }
            for b_txt, b_emb, b_id, b_md in zip(
                _texts,
                embedding_vectors,
                ids,
                metadatas,
            )
        ]
        # make unique by id, keeping the last
        uniqued_documents_to_insert = _unique_list(
            documents_to_insert[::-1],
            lambda document: document["_id"],
        )[::-1]

        all_ids = []

        def _handle_batch(document_batch: List[DocDict]) -> List[str]:
            im_result = self.collection.insert_many(
                documents=document_batch,
                options={"ordered": False},
                partial_failures_allowed=True,
            )
            if "status" not in im_result:
                raise ValueError(
                    f"API Exception while running bulk insertion: {str(im_result)}"
                )

            batch_inserted = im_result["status"]["insertedIds"]
            # estimation of the preexisting documents that failed
            missed_inserted_ids = {
                document["_id"] for document in document_batch
            } - set(batch_inserted)
            errors = im_result.get("errors", [])
            # careful for other sources of error other than "doc already exists"
            num_errors = len(errors)
            unexpected_errors = any(
                error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
            )
            if num_errors != len(missed_inserted_ids) or unexpected_errors:
                raise ValueError(
                    f"API Exception while running bulk insertion: {str(errors)}"
                )

            # deal with the missing insertions as upserts
            missing_from_batch = [
                document
                for document in document_batch
                if document["_id"] in missed_inserted_ids
            ]

            def _handle_missing_document(missing_document: DocDict) -> str:
                replacement_result = self.collection.find_one_and_replace(
                    filter={"_id": missing_document["_id"]},
                    replacement=missing_document,
                )
                return replacement_result["data"]["document"]["_id"]

            _u_max_workers = (
                overwrite_concurrency or self.bulk_insert_overwrite_concurrency
            )
            with ThreadPoolExecutor(max_workers=_u_max_workers) as tpe2:
                batch_replaced = list(
                    tpe2.map(
                        _handle_missing_document,
                        missing_from_batch,
                    )
                )

            upsert_ids = batch_inserted + batch_replaced
            return upsert_ids

        _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
        with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe:
            all_ids_nested = tpe.map(
                _handle_batch,
                batch_iterate(
                    batch_size or self.batch_size,
                    uniqued_documents_to_insert,
                ),
            )

        all_ids = [iid for id_list in all_ids_nested for iid in id_list]

        return all_ids

    def similarity_search_with_score_id_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[Document, float, str]]:
        """Return docs most similar to embedding vector.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score, id), the most similar to the query vector.
        """
        metadata_parameter = self._filter_to_metadata(filter)
        #
        hits = list(
            self.collection.paginated_find(
                filter=metadata_parameter,
                sort={"$vector": embedding},
                options={"limit": k},
                projection={
                    "_id": 1,
                    "content": 1,
                    "metadata": 1,
                    "$similarity": 1,
                },
            )
        )
        #
        return [
            (
                Document(
                    page_content=hit["content"],
                    metadata=hit["metadata"],
                ),
                hit["$similarity"],
                hit["_id"],
            )
            for hit in hits
        ]

    def similarity_search_with_score_id(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[Document, float, str]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_id_by_vector(
            embedding=embedding_vector,
            k=k,
            filter=filter,
        )

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score), the most similar to the query vector.
        """
        return [
            (doc, score)
            for (doc, score, doc_id) in self.similarity_search_with_score_id_by_vector(
                embedding=embedding,
                k=k,
                filter=filter,
            )
        ]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k,
            filter=filter,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        return [
            doc
            for doc, _ in self.similarity_search_with_score_by_vector(
                embedding,
                k,
                filter=filter,
            )
        ]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Tuple[Document, float]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding_vector,
            k,
            filter=filter,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.
        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        metadata_parameter = self._filter_to_metadata(filter)

        prefetch_hits = list(
            self.collection.paginated_find(
                filter=metadata_parameter,
                sort={"$vector": embedding},
                options={"limit": fetch_k},
                projection={
                    "_id": 1,
                    "content": 1,
                    "metadata": 1,
                    "$similarity": 1,
                    "$vector": 1,
                },
            )
        )

        mmr_chosen_indices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_hits = [
            prefetch_hit
            for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
            if prefetch_index in mmr_chosen_indices
        ]
        return [
            Document(
                page_content=hit["content"],
                metadata=hit["metadata"],
            )
            for hit in mmr_hits
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.
        Args:
            query (str): Text to look up documents similar to.
            k (int = 4): Number of Documents to return.
            fetch_k (int = 20): Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult (float = 0.5): Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Optional.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding_vector = self.embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
        )

    @classmethod
    def from_texts(
        cls: Type[ADBVST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ADBVST:
        """Create an Astra DB vectorstore from raw texts.

        Args:
            texts (List[str]): the texts to insert.
            embedding (Embeddings): the embedding function to use in the store.
            metadatas (Optional[List[dict]]): metadata dicts for the texts.
            ids (Optional[List[str]]): ids to associate to the texts.
            *Additional arguments*: you can pass any argument that you would pass
                to 'add_texts' and/or to the 'AstraDB' class constructor
                (see these methods for details). These arguments will be
                routed to the respective methods as they are.

        Returns:
            an `AstraDB` vectorstore.
        """

        known_kwargs = {
            "collection_name",
            "token",
            "api_endpoint",
            "astra_db_client",
            "namespace",
            "metric",
            "batch_size",
            "bulk_insert_batch_concurrency",
            "bulk_insert_overwrite_concurrency",
            "bulk_delete_concurrency",
            "batch_concurrency",
            "overwrite_concurrency",
        }
        if kwargs:
            unknown_kwargs = set(kwargs.keys()) - known_kwargs
            if unknown_kwargs:
                warnings.warn(
                    "Method 'from_texts' of AstraDB vector store invoked with "
                    f"unsupported arguments ({', '.join(sorted(unknown_kwargs))}), "
                    "which will be ignored."
                )

        collection_name: str = kwargs["collection_name"]
        token = kwargs.get("token")
        api_endpoint = kwargs.get("api_endpoint")
        astra_db_client = kwargs.get("astra_db_client")
        namespace = kwargs.get("namespace")
        metric = kwargs.get("metric")

        astra_db_store = cls(
            embedding=embedding,
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            namespace=namespace,
            metric=metric,
            batch_size=kwargs.get("batch_size"),
            bulk_insert_batch_concurrency=kwargs.get("bulk_insert_batch_concurrency"),
            bulk_insert_overwrite_concurrency=kwargs.get(
                "bulk_insert_overwrite_concurrency"
            ),
            bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"),
        )
        astra_db_store.add_texts(
            texts=texts,
            metadatas=metadatas,
            ids=ids,
            batch_size=kwargs.get("batch_size"),
            batch_concurrency=kwargs.get("batch_concurrency"),
            overwrite_concurrency=kwargs.get("overwrite_concurrency"),
        )
        return astra_db_store

    @classmethod
    def from_documents(
        cls: Type[ADBVST],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> ADBVST:
        """Create an Astra DB vectorstore from a document list.

        Utility method that defers to 'from_texts' (see that one).

        Args: see 'from_texts', except here you have to supply 'documents'
            in place of 'texts' and 'metadatas'.

        Returns:
            an `AstraDB` vectorstore.
        """
        return super().from_documents(documents, embedding, **kwargs)
@@ -0,0 +1,468 @@
"""
Test of Astra DB vector store class `AstraDB`

Required to run this test:
    - a recent `astrapy` Python package available
    - an Astra DB instance;
    - the two environment variables set:
        export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
        export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
    - optionally this as well (otherwise defaults are used):
        export ASTRA_DB_KEYSPACE="my_keyspace"
"""

import json
import math
import os
from typing import Iterable, List

import pytest

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import AstraDB

# Ad-hoc embedding classes:


class SomeEmbeddings(Embeddings):
    """
    Turn a sentence into an embedding vector in some way.
    Not important how; that it is deterministic is all that counts.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        unnormed0 = [ord(c) for c in text[: self.dimension]]
        unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[
            : self.dimension
        ]
        norm = sum(x * x for x in unnormed) ** 0.5
        normed = [x / norm for x in unnormed]
        return normed

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


class ParserEmbeddings(Embeddings):
    """
    Parse input texts: if they are json for a List[float], fine.
    Otherwise, return all zeros and call it a day.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(txt) for txt in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        try:
            vals = json.loads(text)
            assert len(vals) == self.dimension
            return vals
        except Exception:
            print(f'[ParserEmbeddings] Returning a moot vector for "{text}"')
            return [0.0] * self.dimension

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture(scope="function")
def store_someemb() -> Iterable[AstraDB]:
    emb = SomeEmbeddings(dimension=2)
    v_store = AstraDB(
        embedding=emb,
        collection_name="lc_test_s",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield v_store
    v_store.delete_collection()


@pytest.fixture(scope="function")
def store_parseremb() -> Iterable[AstraDB]:
    emb = ParserEmbeddings(dimension=2)
    v_store = AstraDB(
        embedding=emb,
        collection_name="lc_test_p",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield v_store
    v_store.delete_collection()


@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDB:
    def test_astradb_vectorstore_create_delete(self) -> None:
        """Create and delete."""
        emb = SomeEmbeddings(dimension=2)
        # creation by passing the connection secrets
        v_store = AstraDB(
            embedding=emb,
            collection_name="lc_test_1",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        v_store.delete_collection()
        # Creation by passing a ready-made astrapy client:
        from astrapy.db import AstraDB as LibAstraDB

        astra_db_client = LibAstraDB(
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        v_store_2 = AstraDB(
            embedding=emb,
            collection_name="lc_test_2",
            astra_db_client=astra_db_client,
        )
        v_store_2.delete_collection()

    def test_astradb_vectorstore_from_x(self) -> None:
        """from_texts and from_documents methods."""
        emb = SomeEmbeddings(dimension=2)
        # from_texts
        v_store = AstraDB.from_texts(
            texts=["Hi", "Ho"],
            embedding=emb,
            collection_name="lc_test_ft",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho"
        v_store.delete_collection()

        # from_documents
        v_store_2 = AstraDB.from_documents(
            [
                Document(page_content="Hee"),
                Document(page_content="Hoi"),
            ],
            embedding=emb,
            collection_name="lc_test_fd",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi"
        # manual collection delete
        v_store_2.delete_collection()

    def test_astradb_vectorstore_crud(self, store_someemb: AstraDB) -> None:
        """Basic add/delete/update behaviour."""
        res0 = store_someemb.similarity_search("Abc", k=2)
        assert res0 == []
        # write and check again
        store_someemb.add_texts(
            texts=["aa", "bb", "cc"],
            metadatas=[
                {"k": "a", "ord": 0},
                {"k": "b", "ord": 1},
                {"k": "c", "ord": 2},
            ],
            ids=["a", "b", "c"],
        )
        res1 = store_someemb.similarity_search("Abc", k=5)
        assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
        # partial overwrite and count total entries
        store_someemb.add_texts(
            texts=["cc", "dd"],
            metadatas=[
                {"k": "c_new", "ord": 102},
                {"k": "d_new", "ord": 103},
            ],
            ids=["c", "d"],
        )
        res2 = store_someemb.similarity_search("Abc", k=10)
        assert len(res2) == 4
        # pick one that was just updated and check its metadata
        res3 = store_someemb.similarity_search_with_score_id("cc", k=1)
        doc3, score3, id3 = res3[0]
        assert doc3.page_content == "cc"
        assert doc3.metadata == {"k": "c_new", "ord": 102}
        assert score3 > 0.999  # leaving some leeway for approximations...
        assert id3 == "c"
        # delete and count again
        del1_res = store_someemb.delete(["b"])
        assert del1_res is True
        del2_res = store_someemb.delete(["a", "c", "Z!"])
        assert del2_res is False  # a non-existing ID was supplied
        assert len(store_someemb.similarity_search("xy", k=10)) == 1
        # clear store
        store_someemb.clear()
        assert store_someemb.similarity_search("Abc", k=2) == []
        # add_documents with "ids" arg passthrough
        store_someemb.add_documents(
            [
                Document(page_content="vv", metadata={"k": "v", "ord": 204}),
                Document(page_content="ww", metadata={"k": "w", "ord": 205}),
            ],
            ids=["v", "w"],
        )
        assert len(store_someemb.similarity_search("xy", k=10)) == 2
        res4 = store_someemb.similarity_search("ww", k=1)
        assert res4[0].metadata["ord"] == 205

    def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDB) -> None:
        """
        MMR testing. We work on the unit circle with angle multiples
        of 2*pi/20 and prepare a store with known vectors for a controlled
        MMR outcome.
        """

        def _v_from_i(i: int, N: int) -> str:
            angle = 2 * math.pi * i / N
            vector = [math.cos(angle), math.sin(angle)]
            return json.dumps(vector)

        i_vals = [0, 4, 5, 13]
        N_val = 20
        store_parseremb.add_texts(
            [_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
        )
        res1 = store_parseremb.max_marginal_relevance_search(
            _v_from_i(3, N_val),
            k=2,
            fetch_k=3,
        )
        res_i_vals = {doc.metadata["i"] for doc in res1}
        assert res_i_vals == {0, 4}

    def test_astradb_vectorstore_metadata(self, store_someemb: AstraDB) -> None:
        """Metadata filtering."""
        store_someemb.add_documents(
            [
                Document(
                    page_content="q",
                    metadata={"ord": ord("q"), "group": "consonant"},
                ),
                Document(
                    page_content="w",
                    metadata={"ord": ord("w"), "group": "consonant"},
                ),
                Document(
                    page_content="r",
                    metadata={"ord": ord("r"), "group": "consonant"},
                ),
                Document(
                    page_content="e",
                    metadata={"ord": ord("e"), "group": "vowel"},
                ),
                Document(
                    page_content="i",
                    metadata={"ord": ord("i"), "group": "vowel"},
                ),
                Document(
                    page_content="o",
                    metadata={"ord": ord("o"), "group": "vowel"},
                ),
            ]
        )
        # no filters
        res0 = store_someemb.similarity_search("x", k=10)
        assert {doc.page_content for doc in res0} == set("qwreio")
        # single filter
        res1 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "vowel"},
        )
        assert {doc.page_content for doc in res1} == set("eio")
        # multiple filters
        res2 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "consonant", "ord": ord("q")},
        )
        assert {doc.page_content for doc in res2} == set("q")
        # excessive filters
        res3 = store_someemb.similarity_search(
            "x",
            k=10,
            filter={"group": "consonant", "ord": ord("q"), "case": "upper"},
        )
        assert res3 == []

    def test_astradb_vectorstore_similarity_scale(
        self, store_parseremb: AstraDB
    ) -> None:
        """Scale of the similarity scores."""
        store_parseremb.add_texts(
            texts=[
                json.dumps([1, 1]),
                json.dumps([-1, -1]),
            ],
            ids=["near", "far"],
        )
        res1 = store_parseremb.similarity_search_with_score(
            json.dumps([0.5, 0.5]),
            k=2,
        )
        scores = [sco for _, sco in res1]
        sco_near, sco_far = scores
        assert abs(1 - sco_near) < 0.001 and abs(sco_far) < 0.001

    def test_astradb_vectorstore_massive_delete(self, store_someemb: AstraDB) -> None:
        """Larger-scale bulk deletes."""
        M = 50
        texts = [str(i + 1 / 7.0) for i in range(2 * M)]
        ids0 = ["doc_%i" % i for i in range(M)]
        ids1 = ["doc_%i" % (i + M) for i in range(M)]
        ids = ids0 + ids1
        store_someemb.add_texts(texts=texts, ids=ids)
        # deleting a bunch of these
        del_res0 = store_someemb.delete(ids0)
        assert del_res0 is True
        # deleting the rest plus a fake one
        del_res1 = store_someemb.delete(ids1 + ["ghost!"])
        assert del_res1 is False  # not *all* ids could be deleted...
        # nothing left
        assert store_someemb.similarity_search("x", k=2 * M) == []

    def test_astradb_vectorstore_drop(self) -> None:
        """Behaviour of 'delete_collection'."""
        emb = SomeEmbeddings(dimension=2)
        v_store = AstraDB(
            embedding=emb,
            collection_name="lc_test_d",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        v_store.add_texts(["huh"])
        assert len(v_store.similarity_search("hah", k=10)) == 1
        # another instance pointing to the same collection on DB
        v_store_kenny = AstraDB(
            embedding=emb,
            collection_name="lc_test_d",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        )
        v_store_kenny.delete_collection()
        # dropped on DB, but 'v_store' should have no clue:
        with pytest.raises(ValueError):
            _ = v_store.similarity_search("hah", k=10)

    def test_astradb_vectorstore_custom_params(self) -> None:
        """Custom batch size and concurrency params."""
        emb = SomeEmbeddings(dimension=2)
        v_store = AstraDB(
            embedding=emb,
            collection_name="lc_test_c",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
            batch_size=17,
            bulk_insert_batch_concurrency=13,
            bulk_insert_overwrite_concurrency=7,
            bulk_delete_concurrency=19,
        )
        # add_texts
        N = 50
        texts = [str(i + 1 / 7.0) for i in range(N)]
        ids = ["doc_%i" % i for i in range(N)]
        v_store.add_texts(texts=texts, ids=ids)
        v_store.add_texts(
            texts=texts,
            ids=ids,
            batch_size=19,
            batch_concurrency=7,
            overwrite_concurrency=13,
        )
        #
        _ = v_store.delete(ids[: N // 2])
        _ = v_store.delete(ids[N // 2 :], concurrency=23)
        #
        v_store.delete_collection()

    def test_astradb_vectorstore_metrics(self) -> None:
        """
        Different choices of similarity metric.
        Both stores (with "cosine" and "euclidean" metrics) contain these two:
            - a vector slightly rotated w.r.t query vector
            - a vector which is a long multiple of query vector
        so, which one is "the closest one" depends on the metric.
        """
        emb = ParserEmbeddings(dimension=2)
        isq2 = 0.5**0.5
        isa = 0.7
        isb = (1.0 - isa * isa) ** 0.5
        texts = [
            json.dumps([isa, isb]),
            json.dumps([10 * isq2, 10 * isq2]),
        ]
        ids = [
            "rotated",
            "scaled",
        ]
        query_text = json.dumps([isq2, isq2])
        # creation, population, query - cosine
        vstore_cos = AstraDB(
            embedding=emb,
            collection_name="lc_test_m_c",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
            metric="cosine",
        )
        vstore_cos.add_texts(
            texts=texts,
            ids=ids,
        )
        _, _, id_from_cos = vstore_cos.similarity_search_with_score_id(
            query_text,
            k=1,
        )[0]
        assert id_from_cos == "scaled"
        vstore_cos.delete_collection()
        # creation, population, query - euclidean
        vstore_euc = AstraDB(
            embedding=emb,
            collection_name="lc_test_m_e",
            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
            namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
            metric="euclidean",
        )
        vstore_euc.add_texts(
            texts=texts,
            ids=ids,
        )
        _, _, id_from_euc = vstore_euc.similarity_search_with_score_id(
            query_text,
            k=1,
        )[0]
        assert id_from_euc == "rotated"
        vstore_euc.delete_collection()
@@ -1125,6 +1125,7 @@ def test_compatible_vectorstore_documentation() -> None:
    # These are mentioned in the indexing.ipynb documentation
    documented = {
        "AnalyticDB",
        "AstraDB",
        "AzureCosmosDBVectorSearch",
        "AwaDB",
        "Bagel",
@@ -11,6 +11,7 @@ _EXPECTED = [
    "AzureSearch",
    "Bagel",
    "Cassandra",
    "AstraDB",
    "Chroma",
    "Clarifai",
    "Clickhouse",