mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00

community[minor]: Add async methods to CassandraVectorStore (#20602)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

This commit is contained in:
parent 06d18c106d
commit c909ae0152
libs/community/langchain_community/utilities/cassandra.py

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable

 if TYPE_CHECKING:
@@ -22,3 +23,9 @@ async def wrapped_response_future(
     response_future.add_callbacks(success_handler, error_handler)
     return await asyncio_future
+
+
+class SetupMode(Enum):
+    SYNC = 1
+    ASYNC = 2
+    OFF = 3
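The enum gives callers three provisioning behaviors for the underlying cassio table: SYNC blocks in the constructor, ASYNC defers schema creation to the event loop (see the vectorstore changes below), and OFF skips provisioning for tables that already exist. A minimal sketch of how a caller might choose a mode; the flag names here are illustrative, not part of the library:

from langchain_community.utilities.cassandra import SetupMode


def pick_setup_mode(table_exists: bool, inside_event_loop: bool) -> SetupMode:
    # Hypothetical selection logic, not library code.
    if table_exists:
        return SetupMode.OFF  # nothing to provision
    if inside_event_loop:
        return SetupMode.ASYNC  # provision without blocking the loop
    return SetupMode.SYNC  # default: block until the table is ready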
libs/community/langchain_community/vectorstores/cassandra.py

@@ -1,9 +1,11 @@
 from __future__ import annotations

+import asyncio
 import typing
 import uuid
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
@@ -24,6 +26,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore

+from langchain_community.utilities.cassandra import SetupMode
 from langchain_community.vectorstores.utils import maximal_marginal_relevance

 CVST = TypeVar("CVST", bound="Cassandra")
@@ -70,6 +73,13 @@ class Cassandra(VectorStore):
             )
         return self._embedding_dimension

+    async def _aget_embedding_dimension(self) -> int:
+        if self._embedding_dimension is None:
+            self._embedding_dimension = len(
+                await self.embedding.aembed_query("This is a sample sentence.")
+            )
+        return self._embedding_dimension
+
     def __init__(
         self,
         embedding: Embeddings,
@@ -79,6 +89,7 @@ class Cassandra(VectorStore):
         ttl_seconds: Optional[int] = None,
         *,
         body_index_options: Optional[List[Tuple[str, Any]]] = None,
+        setup_mode: SetupMode = SetupMode.SYNC,
     ) -> None:
         try:
             from cassio.table import MetadataVectorCassandraTable
@@ -96,17 +107,26 @@
         #
         self._embedding_dimension = None
         #
-        kwargs = {}
+        kwargs: Dict[str, Any] = {}
         if body_index_options is not None:
             kwargs["body_index_options"] = body_index_options
+        if setup_mode == SetupMode.ASYNC:
+            kwargs["async_setup"] = True
+
+        embedding_dimension: Union[int, Awaitable[int], None] = None
+        if setup_mode == SetupMode.ASYNC:
+            embedding_dimension = self._aget_embedding_dimension()
+        elif setup_mode == SetupMode.SYNC:
+            embedding_dimension = self._get_embedding_dimension()
+
         self.table = MetadataVectorCassandraTable(
             session=session,
             keyspace=keyspace,
             table=table_name,
-            vector_dimension=self._get_embedding_dimension(),
+            vector_dimension=embedding_dimension,
             metadata_indexing="all",
             primary_key_type="TEXT",
+            skip_provisioning=setup_mode == SetupMode.OFF,
             **kwargs,
         )

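Note the type of embedding_dimension above: in SYNC mode it is a plain int obtained from a blocking embed call, while in ASYNC mode it is an un-awaited coroutine that cassio is expected to resolve later as part of its async table setup (hence async_setup=True). The value-or-awaitable pattern in isolation, a small sketch with made-up names:

import asyncio
import inspect
from typing import Awaitable, Union


async def measure_dimension_async() -> int:
    return 1536  # stand-in for len(await embedding.aembed_query(...))


async def provision(dimension: Union[int, Awaitable[int]]) -> int:
    # The consumer (playing cassio's role here) accepts either form.
    if inspect.isawaitable(dimension):
        dimension = await dimension
    return dimension


print(asyncio.run(provision(1536)))                       # sync caller: eager value
print(asyncio.run(provision(measure_dimension_async())))  # async caller: deferred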
@@ -129,17 +149,30 @@
         """
         self.clear()

+    async def adelete_collection(self) -> None:
+        """
+        Just an alias for `aclear`
+        (to better align with other VectorStore implementations).
+        """
+        await self.aclear()
+
     def clear(self) -> None:
         """Empty the table."""
         self.table.clear()

+    async def aclear(self) -> None:
+        """Empty the table."""
+        await self.table.aclear()
+
     def delete_by_document_id(self, document_id: str) -> None:
         return self.table.delete(row_id=document_id)

+    async def adelete_by_document_id(self, document_id: str) -> None:
+        return await self.table.adelete(row_id=document_id)
+
     def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
         """Delete by vector IDs.

         Args:
             ids: List of ids to delete.

@@ -155,6 +188,26 @@
             self.delete_by_document_id(document_id)
         return True

+    async def adelete(
+        self, ids: Optional[List[str]] = None, **kwargs: Any
+    ) -> Optional[bool]:
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+
+        Returns:
+            Optional[bool]: True if deletion is successful,
+            False otherwise, None if not implemented.
+        """
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        for document_id in ids:
+            await self.adelete_by_document_id(document_id)
+        return True
+
     def add_texts(
         self,
         texts: Iterable[str],
@@ -176,16 +229,12 @@
         Returns:
             List[str]: List of IDs of the added texts.
         """
-        _texts = list(texts)  # lest it be a generator or something
-        if ids is None:
-            ids = [uuid.uuid4().hex for _ in _texts]
-        if metadatas is None:
-            metadatas = [{} for _ in _texts]
-        #
+        _texts = list(texts)
+        ids = ids or [uuid.uuid4().hex for _ in _texts]
+        metadatas = metadatas or [{}] * len(_texts)
         ttl_seconds = ttl_seconds or self.ttl_seconds
-        #
         embedding_vectors = self.embedding.embed_documents(_texts)
-        #
+
         for i in range(0, len(_texts), batch_size):
             batch_texts = _texts[i : i + batch_size]
             batch_embedding_vectors = embedding_vectors[i : i + batch_size]
@@ -208,6 +257,77 @@
             future.result()
         return ids

+    async def aadd_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        concurrency: int = 16,
+        ttl_seconds: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Run more texts through the embeddings and add to the vectorstore.
+
+        Args:
+            texts: Texts to add to the vectorstore.
+            metadatas: Optional list of metadatas.
+            ids: Optional list of IDs.
+            concurrency: Number of concurrent queries to the database.
+                Defaults to 16.
+            ttl_seconds: Optional time-to-live for the added texts.
+
+        Returns:
+            List[str]: List of IDs of the added texts.
+        """
+        _texts = list(texts)
+        ids = ids or [uuid.uuid4().hex for _ in _texts]
+        _metadatas: List[dict] = metadatas or [{}] * len(_texts)
+        ttl_seconds = ttl_seconds or self.ttl_seconds
+        embedding_vectors = await self.embedding.aembed_documents(_texts)
+
+        sem = asyncio.Semaphore(concurrency)
+
+        async def send_concurrently(
+            row_id: str, text: str, embedding_vector: List[float], metadata: dict
+        ) -> None:
+            async with sem:
+                await self.table.aput(
+                    row_id=row_id,
+                    body_blob=text,
+                    vector=embedding_vector,
+                    metadata=metadata or {},
+                    ttl_seconds=ttl_seconds,
+                )
+
+        for i in range(0, len(_texts)):
+            tasks = [
+                asyncio.create_task(
+                    send_concurrently(
+                        ids[i], _texts[i], embedding_vectors[i], _metadatas[i]
+                    )
+                )
+            ]
+            await asyncio.gather(*tasks)
+        return ids
+
+    @staticmethod
+    def _search_to_documents(
+        hits: Iterable[Dict[str, Any]],
+    ) -> List[Tuple[Document, float, str]]:
+        # We stick to 'cos' distance as it can be normalized on a 0-1 axis
+        # (1=most relevant), as required by this class' contract.
+        return [
+            (
+                Document(
+                    page_content=hit["body_blob"],
+                    metadata=hit["metadata"],
+                ),
+                0.5 + 0.5 * hit["distance"],
+                hit["row_id"],
+            )
+            for hit in hits
+        ]
+
     # id-returning search facilities
     def similarity_search_with_score_id_by_vector(
         self,
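One detail worth flagging in aadd_texts as committed: each loop iteration creates a single-element task list and immediately awaits gather on it, so rows are in effect written one at a time and the semaphore never admits more than one writer. A sketch of the presumably intended shape, scheduling every row before a single gather (a drop-in replacement for the loop above, reusing the function's own local names):

        # Sketch only: schedule all rows first, then await them together,
        # so the semaphore can actually admit up to `concurrency` writers.
        tasks = [
            asyncio.create_task(
                send_concurrently(
                    ids[i], _texts[i], embedding_vectors[i], _metadatas[i]
                )
            )
            for i in range(len(_texts))
        ]
        await asyncio.gather(*tasks)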
@@ -232,26 +352,46 @@
             kwargs["metadata"] = filter
         if body_search is not None:
             kwargs["body_search"] = body_search
-        #
+
         hits = self.table.metric_ann_search(
             vector=embedding,
             n=k,
             metric="cos",
             **kwargs,
         )
-        # We stick to 'cos' distance as it can be normalized on a 0-1 axis
-        # (1=most relevant), as required by this class' contract.
-        return [
-            (
-                Document(
-                    page_content=hit["body_blob"],
-                    metadata=hit["metadata"],
-                ),
-                0.5 + 0.5 * hit["distance"],
-                hit["row_id"],
-            )
-            for hit in hits
-        ]
+        return self._search_to_documents(hits)
+
+    async def asimilarity_search_with_score_id_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, float, str]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding (str): Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+        Returns:
+            List of (Document, score, id), the most similar to the query vector.
+        """
+        kwargs: Dict[str, Any] = {}
+        if filter is not None:
+            kwargs["metadata"] = filter
+        if body_search is not None:
+            kwargs["body_search"] = body_search
+
+        hits = await self.table.ametric_ann_search(
+            vector=embedding,
+            n=k,
+            metric="cos",
+            **kwargs,
+        )
+        return self._search_to_documents(hits)
+
     def similarity_search_with_score_id(
         self,
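The factored-out _search_to_documents helper preserves the score convention both paths share: judging by the inline comment, cassio's "cos" metric returns a value in [-1, 1] where 1 means most similar, and 0.5 + 0.5 * distance maps it onto the [0, 1] relevance axis the VectorStore contract expects. The mapping at its extremes, as a quick self-contained check:

def normalize_cosine(similarity: float) -> float:
    # Affine map from cosine similarity in [-1, 1] to relevance in [0, 1].
    return 0.5 + 0.5 * similarity


assert normalize_cosine(1.0) == 1.0   # same direction: most relevant
assert normalize_cosine(0.0) == 0.5   # orthogonal
assert normalize_cosine(-1.0) == 0.0  # opposite direction: least relevant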
@@ -268,6 +408,21 @@
             body_search=body_search,
         )

+    async def asimilarity_search_with_score_id(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, float, str]]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_with_score_id_by_vector(
+            embedding=embedding_vector,
+            k=k,
+            filter=filter,
+            body_search=body_search,
+        )
+
     # id-unaware search facilities
     def similarity_search_with_score_by_vector(
         self,
@@ -297,6 +452,38 @@
             )
         ]

+    async def asimilarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding (str): Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+        Returns:
+            List of (Document, score), the most similar to the query vector.
+        """
+        return [
+            (doc, score)
+            for (
+                doc,
+                score,
+                _,
+            ) in await self.asimilarity_search_with_score_id_by_vector(
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                body_search=body_search,
+            )
+        ]
+
     def similarity_search(
         self,
         query: str,
@@ -313,6 +500,22 @@
             body_search=body_search,
         )

+    async def asimilarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_by_vector(
+            embedding_vector,
+            k,
+            filter=filter,
+            body_search=body_search,
+        )
+
     def similarity_search_by_vector(
         self,
         embedding: List[float],
@@ -331,6 +534,24 @@
             )
         ]

+    async def asimilarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        return [
+            doc
+            for doc, _ in await self.asimilarity_search_with_score_by_vector(
+                embedding,
+                k,
+                filter=filter,
+                body_search=body_search,
+            )
+        ]
+
     def similarity_search_with_score(
         self,
         query: str,
@@ -346,6 +567,48 @@
             body_search=body_search,
         )

+    async def asimilarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+    ) -> List[Tuple[Document, float]]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.asimilarity_search_with_score_by_vector(
+            embedding_vector,
+            k,
+            filter=filter,
+            body_search=body_search,
+        )
+
+    @staticmethod
+    def _mmr_search_to_documents(
+        prefetch_hits: List[Dict[str, Any]],
+        embedding: List[float],
+        k: int,
+        lambda_mult: float,
+    ) -> List[Document]:
+        # let the mmr utility pick the *indices* in the above array
+        mmr_chosen_indices = maximal_marginal_relevance(
+            np.array(embedding, dtype=np.float32),
+            [pf_hit["vector"] for pf_hit in prefetch_hits],
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+        mmr_hits = [
+            pf_hit
+            for pf_index, pf_hit in enumerate(prefetch_hits)
+            if pf_index in mmr_chosen_indices
+        ]
+        return [
+            Document(
+                page_content=hit["body_blob"],
+                metadata=hit["metadata"],
+            )
+            for hit in mmr_hits
+        ]
+
     def max_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -388,25 +651,51 @@
                 **_kwargs,
             )
         )
-        # let the mmr utility pick the *indices* in the above array
-        mmr_chosen_indices = maximal_marginal_relevance(
-            np.array(embedding, dtype=np.float32),
-            [pf_hit["vector"] for pf_hit in prefetch_hits],
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        mmr_hits = [
-            pf_hit
-            for pf_index, pf_hit in enumerate(prefetch_hits)
-            if pf_index in mmr_chosen_indices
-        ]
-        return [
-            Document(
-                page_content=hit["body_blob"],
-                metadata=hit["metadata"],
-            )
-            for hit in mmr_hits
-        ]
+        return self._mmr_search_to_documents(prefetch_hits, embedding, k, lambda_mult)
+
+    async def amax_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding to maximum
+                diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        _kwargs: Dict[str, Any] = {}
+        if filter is not None:
+            _kwargs["metadata"] = filter
+        if body_search is not None:
+            _kwargs["body_search"] = body_search
+
+        prefetch_hits = list(
+            await self.table.ametric_ann_search(
+                vector=embedding,
+                n=fetch_k,
+                metric="cos",
+                **_kwargs,
+            )
+        )
+        return self._mmr_search_to_documents(prefetch_hits, embedding, k, lambda_mult)

     def max_marginal_relevance_search(
         self,
@@ -446,6 +735,43 @@
             body_search=body_search,
         )

+    async def amax_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        body_search: Optional[Union[str, List[str]]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding to maximum
+                diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter: Filter on the metadata to apply.
+            body_search: Document textual search terms to apply.
+                Only supported by Astra DB at the moment.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        embedding_vector = await self.embedding.aembed_query(query)
+        return await self.amax_marginal_relevance_search_by_vector(
+            embedding_vector,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+            body_search=body_search,
+        )
+
     @classmethod
     def from_texts(
         cls: Type[CVST],
@@ -500,6 +826,61 @@
             )
         return store

+    @classmethod
+    async def afrom_texts(
+        cls: Type[CVST],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        *,
+        session: Session = _NOT_SET,
+        keyspace: str = "",
+        table_name: str = "",
+        ids: Optional[List[str]] = None,
+        concurrency: int = 16,
+        ttl_seconds: Optional[int] = None,
+        body_index_options: Optional[List[Tuple[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> CVST:
+        """Create a Cassandra vectorstore from raw texts.
+
+        Args:
+            texts: Texts to add to the vectorstore.
+            embedding: Embedding function to use.
+            metadatas: Optional list of metadatas associated with the texts.
+            session: Cassandra driver session (required).
+            keyspace: Cassandra key space (required).
+            table_name: Cassandra table (required).
+            ids: Optional list of IDs associated with the texts.
+            concurrency: Number of concurrent queries to send to the database.
+                Defaults to 16.
+            ttl_seconds: Optional time-to-live for the added texts.
+            body_index_options: Optional options used to create the body index.
+                Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+
+        Returns:
+            a Cassandra vectorstore.
+        """
+        if session is _NOT_SET:
+            raise ValueError("session parameter is required")
+        if not keyspace:
+            raise ValueError("keyspace parameter is required")
+        if not table_name:
+            raise ValueError("table_name parameter is required")
+        store = cls(
+            embedding=embedding,
+            session=session,
+            keyspace=keyspace,
+            table_name=table_name,
+            ttl_seconds=ttl_seconds,
+            setup_mode=SetupMode.ASYNC,
+            body_index_options=body_index_options,
+        )
+        await store.aadd_texts(
+            texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
+        )
+        return store
+
     @classmethod
     def from_documents(
         cls: Type[CVST],
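afrom_texts bundles construction (forcing SetupMode.ASYNC) and aadd_texts into one awaitable call, so nothing blocks the event loop. A minimal usage sketch, assuming a reachable Cassandra cluster; the keyspace and table names are placeholders:

import asyncio

from cassandra.cluster import Cluster
from langchain_community.embeddings import FakeEmbeddings  # placeholder embeddings
from langchain_community.vectorstores import Cassandra


async def main() -> None:
    session = Cluster().connect()  # assumes a local cluster is running
    store = await Cassandra.afrom_texts(
        ["foo", "bar", "baz"],
        FakeEmbeddings(size=16),
        session=session,
        keyspace="demo_keyspace",   # placeholder
        table_name="demo_table",    # placeholder
    )
    print(await store.asimilarity_search("foo", k=1))


asyncio.run(main())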
@@ -548,3 +929,52 @@
             body_index_options=body_index_options,
             **kwargs,
         )
+
+    @classmethod
+    async def afrom_documents(
+        cls: Type[CVST],
+        documents: List[Document],
+        embedding: Embeddings,
+        *,
+        session: Session = _NOT_SET,
+        keyspace: str = "",
+        table_name: str = "",
+        ids: Optional[List[str]] = None,
+        concurrency: int = 16,
+        ttl_seconds: Optional[int] = None,
+        body_index_options: Optional[List[Tuple[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> CVST:
+        """Create a Cassandra vectorstore from a document list.
+
+        Args:
+            documents: Documents to add to the vectorstore.
+            embedding: Embedding function to use.
+            session: Cassandra driver session (required).
+            keyspace: Cassandra key space (required).
+            table_name: Cassandra table (required).
+            ids: Optional list of IDs associated with the documents.
+            concurrency: Number of concurrent queries to send to the database.
+                Defaults to 16.
+            ttl_seconds: Optional time-to-live for the added documents.
+            body_index_options: Optional options used to create the body index.
+                Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
+
+        Returns:
+            a Cassandra vectorstore.
+        """
+        texts = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        return await cls.afrom_texts(
+            texts=texts,
+            embedding=embedding,
+            metadatas=metadatas,
+            session=session,
+            keyspace=keyspace,
+            table_name=table_name,
+            ids=ids,
+            concurrency=concurrency,
+            ttl_seconds=ttl_seconds,
+            body_index_options=body_index_options,
+            **kwargs,
+        )
libs/community/poetry.lock (generated, 28 changes)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

 [[package]]
 name = "aenum"
@@ -738,19 +738,19 @@ graph = ["gremlinpython (==3.4.6)"]

 [[package]]
 name = "cassio"
-version = "0.1.5"
+version = "0.1.6"
 description = "A framework-agnostic Python library to seamlessly integrate Apache Cassandra(R) with ML/LLM/genAI workloads."
 optional = false
-python-versions = ">=3.8"
+python-versions = "<4.0,>=3.8"
 files = [
-    {file = "cassio-0.1.5-py3-none-any.whl", hash = "sha256:cf1d11f255c040bc0aede4963ca020840133377aa54f7f15d2f819d6553d52ce"},
-    {file = "cassio-0.1.5.tar.gz", hash = "sha256:88c50c34d46a1bfffca1e0b600318a6efef45e6c18a56ddabe208cbede8dcc27"},
+    {file = "cassio-0.1.6-py3-none-any.whl", hash = "sha256:2ab767da43acdd850b2fb0eead7f0fd9cbb2884bb3864c6b0721dd589cbfe23a"},
+    {file = "cassio-0.1.6.tar.gz", hash = "sha256:338ed89bd3dfdd7225b72ae70af2d7e058eb30582814b9f146a70f84a8d345f7"},
 ]

 [package.dependencies]
-cassandra-driver = ">=3.28.0"
+cassandra-driver = ">=3.28.0,<4.0.0"
 numpy = ">=1.0"
-requests = ">=2"
+requests = ">=2.31.0,<3.0.0"

 [[package]]
 name = "cerberus"
|
||||
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
|
||||
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
|
||||
{file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
|
||||
{file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
|
||||
{file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
|
||||
@@ -5528,8 +5529,6 @@ files = [
     {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
     {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
     {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
     {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@@ -5572,7 +5571,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -5581,8 +5579,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
    {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -6115,26 +6111,31 @@ python-versions = ">=3.8"
 files = [
+    {file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"},
     {file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"},
     {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"},
     {file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"},
     {file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"},
     {file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"},
+    {file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"},
     {file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"},
     {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"},
     {file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"},
     {file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"},
     {file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"},
+    {file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"},
     {file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"},
     {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"},
     {file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"},
     {file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"},
     {file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"},
+    {file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"},
     {file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"},
     {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"},
     {file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"},
     {file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"},
     {file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"},
+    {file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"},
     {file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"},
     {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"},
     {file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"},
     {file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"},
     {file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"},
@@ -6575,7 +6576,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
-    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -9234,4 +9234,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "c604b9bee1c9e9318178a5c356ad4f7a1c9ccc94bd329b0e909f56066cc254fc"
+content-hash = "48ea73a94d06ae90f8f089017ae1bbcf9d37b2cc9957a44fb617785be0fe3236"
libs/community/pyproject.toml

@@ -53,7 +53,7 @@ mwxml = {version = "^0.3.3", optional = true}
 esprima = {version = "^4.0.1", optional = true}
 streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
 psychicapi = {version = "^0.8.0", optional = true}
-cassio = {version = "^0.1.0", optional = true}
+cassio = {version = "^0.1.6", optional = true}
 sympy = {version = "^1.12", optional = true}
 rapidfuzz = {version = "^3.1.1", optional = true}
 jsonschema = {version = ">1", optional = true}
@@ -153,7 +153,7 @@ pytest-vcr = "^1.0.2"
 wrapt = "^1.15.0"
 openai = "^1"
 python-dotenv = "^1.0.0"
-cassio = "^0.1.0"
+cassio = "^0.1.6"
 tiktoken = ">=0.3.2,<0.6.0"
 anthropic = "^0.3.11"
 langchain-core = { path = "../core", develop = true }
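The cassio floor moves from ^0.1.0 to ^0.1.6 because the async paths above lean on cassio's awaitable table API (aput, adelete, aclear, ametric_ann_search, plus the async_setup keyword); presumably 0.1.6 is the first release carrying all of them. For an existing environment the equivalent upgrade would be along the lines of:

    pip install --upgrade "cassio>=0.1.6,<0.2"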
libs/community/tests/integration_tests/vectorstores/test_cassandra.py

@@ -1,10 +1,12 @@
 """Test Cassandra functionality."""
+import asyncio
 import time
 from typing import List, Optional, Type

 from langchain_core.documents import Document

 from langchain_community.vectorstores import Cassandra
+from langchain_community.vectorstores.cassandra import SetupMode
 from tests.integration_tests.vectorstores.fake_embeddings import (
     AngularTwoDimensionalEmbeddings,
     ConsistentFakeEmbeddings,
@@ -46,31 +48,77 @@ def _vectorstore_from_texts(
     )


-def test_cassandra() -> None:
+async def _vectorstore_from_texts_async(
+    texts: List[str],
+    metadatas: Optional[List[dict]] = None,
+    embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
+    drop: bool = True,
+) -> Cassandra:
+    from cassandra.cluster import Cluster
+
+    keyspace = "vector_test_keyspace"
+    table_name = "vector_test_table"
+    # get db connection
+    cluster = Cluster()
+    session = cluster.connect()
+    # ensure keyspace exists
+    session.execute(
+        (
+            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
+            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
+        )
+    )
+    # drop table if required
+    if drop:
+        session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
+    #
+    return await Cassandra.afrom_texts(
+        texts,
+        embedding_class(),
+        metadatas=metadatas,
+        session=session,
+        keyspace=keyspace,
+        table_name=table_name,
+        setup_mode=SetupMode.ASYNC,
+    )
+
+
+async def test_cassandra() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
     docsearch = _vectorstore_from_texts(texts)
     output = docsearch.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo")]
+    output = await docsearch.asimilarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]


-def test_cassandra_with_score() -> None:
+async def test_cassandra_with_score() -> None:
     """Test end to end construction and search with scores and IDs."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
     docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
-    output = docsearch.similarity_search_with_score("foo", k=3)
-    docs = [o[0] for o in output]
-    scores = [o[1] for o in output]
-    assert docs == [
+
+    expected_docs = [
         Document(page_content="foo", metadata={"page": "0.0"}),
         Document(page_content="bar", metadata={"page": "1.0"}),
         Document(page_content="baz", metadata={"page": "2.0"}),
     ]
+
+    output = docsearch.similarity_search_with_score("foo", k=3)
+    docs = [o[0] for o in output]
+    scores = [o[1] for o in output]
+    assert docs == expected_docs
     assert scores[0] > scores[1] > scores[2]

+    output = await docsearch.asimilarity_search_with_score("foo", k=3)
+    docs = [o[0] for o in output]
+    scores = [o[1] for o in output]
+    assert docs == expected_docs
+    assert scores[0] > scores[1] > scores[2]
+

-def test_cassandra_max_marginal_relevance_search() -> None:
+async def test_cassandra_max_marginal_relevance_search() -> None:
     """
     Test end to end construction and MMR search.
     The embedding function used here ensures `texts` become
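These test coroutines carry no event-loop decorator, so they presumably rely on an async-capable pytest plugin (such as pytest-asyncio in auto mode) being configured for the integration suite. Run outside pytest, a single case would need an explicit loop; a hypothetical direct invocation:

import asyncio

asyncio.run(test_cassandra())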
@@ -91,17 +139,26 @@ def test_cassandra_max_marginal_relevance_search() -> None:
     docsearch = _vectorstore_from_texts(
         texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
     )
-    output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
-    output_set = {
-        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
-    }
-    assert output_set == {
+
+    expected_set = {
         ("+0.25", "2.0"),
         ("-0.124", "0.0"),
     }

+    output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
+    output_set = {
+        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
+    }
+    assert output_set == expected_set

-def test_cassandra_add_extra() -> None:
+    output = await docsearch.amax_marginal_relevance_search("0.0", k=2, fetch_k=3)
+    output_set = {
+        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
+    }
+    assert output_set == expected_set
+
+
+def test_cassandra_add_texts() -> None:
     """Test end to end construction with further insertions."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
@@ -115,12 +172,25 @@ def test_cassandra_add_extra() -> None:
     assert len(output) == 6


+async def test_cassandra_aadd_texts() -> None:
+    """Test end to end construction with further insertions."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
+
+    texts2 = ["foo2", "bar2", "baz2"]
+    metadatas2 = [{"page": i + 3} for i in range(len(texts))]
+    await docsearch.aadd_texts(texts2, metadatas2)
+
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 6
+
+
 def test_cassandra_no_drop() -> None:
     """Test end to end construction and re-opening the same index."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
-    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
-    del docsearch
+    _vectorstore_from_texts(texts, metadatas=metadatas)

     texts2 = ["foo2", "bar2", "baz2"]
     docsearch = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False)
@@ -129,6 +199,21 @@ def test_cassandra_no_drop() -> None:
     assert len(output) == 6


+async def test_cassandra_no_drop_async() -> None:
+    """Test end to end construction and re-opening the same index."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    await _vectorstore_from_texts_async(texts, metadatas=metadatas)
+
+    texts2 = ["foo2", "bar2", "baz2"]
+    docsearch = await _vectorstore_from_texts_async(
+        texts2, metadatas=metadatas, drop=False
+    )
+
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 6
+
+
 def test_cassandra_delete() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
@@ -155,3 +240,31 @@ def test_cassandra_delete() -> None:
     time.sleep(0.3)
     output = docsearch.similarity_search("foo", k=10)
     assert len(output) == 0
+
+
+async def test_cassandra_adelete() -> None:
+    """Test delete methods from vector store."""
+    texts = ["foo", "bar", "baz", "gni"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas)
+
+    ids = await docsearch.aadd_texts(texts, metadatas)
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 4
+
+    await docsearch.adelete_by_document_id(ids[0])
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 3
+
+    await docsearch.adelete(ids[1:3])
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 1
+
+    await docsearch.adelete(["not-existing"])
+    output = await docsearch.asimilarity_search("foo", k=10)
+    assert len(output) == 1
+
+    await docsearch.aclear()
+    await asyncio.sleep(0.3)
+    output = docsearch.similarity_search("foo", k=10)
+    assert len(output) == 0
Loading…
Reference in New Issue
Block a user