mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
Second Attempt - Add concurrent insertion of vector rows in the Cassandra Vector Store (#7017)
Retrying with the same improvements as in #6772, this time taking care not to mess up the branches. @rlancemartin, I am opening a fresh PR from a branch with a new name. This should do it. Thank you for your help! --------- Co-authored-by: Jonathan Ellis <jbellis@datastax.com> Co-authored-by: rlm <pexpresss31@gmail.com>
This commit is contained in:
parent
3bfe7cf467
commit
8d2281a8ca
@ -16,6 +16,16 @@ pip install cassio
|
||||
|
||||
|
||||
|
||||
## Vector Store
|
||||
|
||||
See a [usage example](/docs/modules/data_connection/vectorstores/integrations/cassandra.html).
|
||||
|
||||
```python
|
||||
from langchain.vectorstores import Cassandra
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Memory
|
||||
|
||||
See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).
|
||||
|
@ -23,7 +23,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install \"cassio>=0.0.5\""
|
||||
"!pip install \"cassio>=0.0.7\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -44,14 +44,16 @@
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"\n",
|
||||
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
|
||||
"database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
|
||||
"\n",
|
||||
"keyspace_name = input('\\nKeyspace name? ')\n",
|
||||
"\n",
|
||||
"if database_mode == 'A':\n",
|
||||
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
|
||||
" #\n",
|
||||
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
|
||||
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
|
||||
"elif database_mode == 'C':\n",
|
||||
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -72,7 +74,14 @@
|
||||
"from cassandra.cluster import Cluster\n",
|
||||
"from cassandra.auth import PlainTextAuthProvider\n",
|
||||
"\n",
|
||||
"if database_mode == 'L':\n",
|
||||
"if database_mode == 'C':\n",
|
||||
" if CASSANDRA_CONTACT_POINTS:\n",
|
||||
" cluster = Cluster([\n",
|
||||
" cp.strip()\n",
|
||||
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
|
||||
" if cp.strip()\n",
|
||||
" ])\n",
|
||||
" else:\n",
|
||||
" cluster = Cluster()\n",
|
||||
" session = cluster.connect()\n",
|
||||
"elif database_mode == 'A':\n",
|
||||
@ -261,7 +270,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
@ -17,14 +17,6 @@ from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
CVST = TypeVar("CVST", bound="Cassandra")
|
||||
|
||||
# a positive number of seconds to expire entries, or None for no expiration.
|
||||
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None
|
||||
|
||||
|
||||
def _hash(_input: str) -> str:
|
||||
"""Use a deterministic hashing approach."""
|
||||
return hashlib.md5(_input.encode()).hexdigest()
|
||||
|
||||
|
||||
class Cassandra(VectorStore):
|
||||
"""Wrapper around Cassandra embeddings platform.
|
||||
@ -46,7 +38,7 @@ class Cassandra(VectorStore):
|
||||
|
||||
_embedding_dimension: int | None
|
||||
|
||||
def _getEmbeddingDimension(self) -> int:
|
||||
def _get_embedding_dimension(self) -> int:
|
||||
if self._embedding_dimension is None:
|
||||
self._embedding_dimension = len(
|
||||
self.embedding.embed_query("This is a sample sentence.")
|
||||
@ -59,7 +51,7 @@ class Cassandra(VectorStore):
|
||||
session: Session,
|
||||
keyspace: str,
|
||||
table_name: str,
|
||||
ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
try:
|
||||
from cassio.vector import VectorTable
|
||||
@ -81,8 +73,8 @@ class Cassandra(VectorStore):
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
table=table_name,
|
||||
embedding_dimension=self._getEmbeddingDimension(),
|
||||
auto_id=False, # the `add_texts` contract admits user-provided ids
|
||||
embedding_dimension=self._get_embedding_dimension(),
|
||||
primary_key_type="TEXT",
|
||||
)
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
@ -99,11 +91,27 @@ class Cassandra(VectorStore):
|
||||
def delete_by_document_id(self, document_id: str) -> None:
|
||||
return self.table.delete(document_id)
|
||||
|
||||
def delete(self, ids: List[str]) -> Optional[bool]:
|
||||
"""Delete by vector ID.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
|
||||
Returns:
|
||||
Optional[bool]: True if deletion is successful,
|
||||
False otherwise, None if not implemented.
|
||||
"""
|
||||
for document_id in ids:
|
||||
self.delete_by_document_id(document_id)
|
||||
return True
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
batch_size: int = 16,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
@ -112,33 +120,39 @@ class Cassandra(VectorStore):
|
||||
texts (Iterable[str]): Texts to add to the vectorstore.
|
||||
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||
ids (Optional[List[str]], optional): Optional list of IDs.
|
||||
batch_size (int): Number of concurrent requests to send to the server.
|
||||
ttl_seconds (Optional[int], optional): Optional time-to-live
|
||||
for the added texts.
|
||||
|
||||
Returns:
|
||||
List[str]: List of IDs of the added texts.
|
||||
"""
|
||||
_texts = list(texts) # lest it be a generator or something
|
||||
if ids is None:
|
||||
# unless otherwise specified, we have deterministic IDs:
|
||||
# re-inserting an existing document will not create a duplicate.
|
||||
# (and effectively update the metadata)
|
||||
ids = [_hash(text) for text in _texts]
|
||||
ids = [uuid.uuid4().hex for _ in _texts]
|
||||
if metadatas is None:
|
||||
metadatas = [{} for _ in _texts]
|
||||
#
|
||||
ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
|
||||
ttl_seconds = ttl_seconds or self.ttl_seconds
|
||||
#
|
||||
embedding_vectors = self.embedding.embed_documents(_texts)
|
||||
for text, embedding_vector, text_id, metadata in zip(
|
||||
_texts, embedding_vectors, ids, metadatas
|
||||
):
|
||||
self.table.put(
|
||||
document=text,
|
||||
embedding_vector=embedding_vector,
|
||||
document_id=text_id,
|
||||
metadata=metadata,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
#
|
||||
for i in range(0, len(_texts), batch_size):
|
||||
batch_texts = _texts[i : i + batch_size]
|
||||
batch_embedding_vectors = embedding_vectors[i : i + batch_size]
|
||||
batch_ids = ids[i : i + batch_size]
|
||||
batch_metadatas = metadatas[i : i + batch_size]
|
||||
|
||||
futures = [
|
||||
self.table.put_async(
|
||||
text, embedding_vector, text_id, metadata, ttl_seconds
|
||||
)
|
||||
for text, embedding_vector, text_id, metadata in zip(
|
||||
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
|
||||
)
|
||||
]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return ids
|
||||
|
||||
# id-returning search facilities
|
||||
@ -181,7 +195,6 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float, str]]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_with_score_id_by_vector(
|
||||
@ -219,12 +232,10 @@ class Cassandra(VectorStore):
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
#
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_by_vector(
|
||||
embedding_vector,
|
||||
k,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
@ -245,7 +256,6 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_with_score_by_vector(
|
||||
@ -266,7 +276,6 @@ class Cassandra(VectorStore):
|
||||
return self.similarity_search_with_score(
|
||||
query,
|
||||
k,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
@ -352,6 +361,7 @@ class Cassandra(VectorStore):
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
batch_size: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from raw texts.
|
||||
@ -378,6 +388,7 @@ class Cassandra(VectorStore):
|
||||
cls: Type[CVST],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
batch_size: int = 16,
|
||||
**kwargs: Any,
|
||||
) -> CVST:
|
||||
"""Create a Cassandra vectorstore from a document list.
|
||||
|
552
poetry.lock
generated
552
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -113,7 +113,7 @@ esprima = {version = "^4.0.1", optional = true}
|
||||
openllm = {version = ">=0.1.19", optional = true}
|
||||
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
|
||||
psychicapi = {version = "^0.8.0", optional = true}
|
||||
cassio = {version = "^0.0.6", optional = true}
|
||||
cassio = {version = "^0.0.7", optional = true}
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
autodoc_pydantic = "^1.8.0"
|
||||
@ -188,7 +188,7 @@ gptcache = "^0.1.9"
|
||||
promptlayer = "^0.1.80"
|
||||
tair = "^1.3.3"
|
||||
wikipedia = "^1"
|
||||
cassio = "^0.0.6"
|
||||
cassio = "^0.0.7"
|
||||
arxiv = "^1.4"
|
||||
mastodon-py = "^1.8.1"
|
||||
momento = "^1.5.0"
|
||||
|
@ -84,7 +84,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
|
||||
With fetch_k==3 and k==2, when query is at (1, ),
|
||||
one expects that v2 and v0 are returned (in some order).
|
||||
"""
|
||||
texts = ["-0.125", "+0.125", "+0.25", "+1.0"]
|
||||
texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = _vectorstore_from_texts(
|
||||
texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
|
||||
@ -95,7 +95,7 @@ def test_cassandra_max_marginal_relevance_search() -> None:
|
||||
}
|
||||
assert output_set == {
|
||||
("+0.25", 2),
|
||||
("-0.125", 0),
|
||||
("-0.124", 0),
|
||||
}
|
||||
|
||||
|
||||
@ -105,9 +105,9 @@ def test_cassandra_add_extra() -> None:
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
|
||||
|
||||
docsearch.add_texts(texts, metadatas)
|
||||
texts2 = ["foo2", "bar2", "baz2"]
|
||||
docsearch.add_texts(texts2, metadatas)
|
||||
metadatas2 = [{"page": i + 3} for i in range(len(texts))]
|
||||
docsearch.add_texts(texts2, metadatas2)
|
||||
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 6
|
||||
@ -127,9 +127,37 @@ def test_cassandra_no_drop() -> None:
|
||||
assert len(output) == 6
|
||||
|
||||
|
||||
def test_cassandra_delete() -> None:
|
||||
"""Test delete methods from vector store."""
|
||||
texts = ["foo", "bar", "baz", "gni"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = _vectorstore_from_texts([], metadatas=metadatas)
|
||||
|
||||
ids = docsearch.add_texts(texts, metadatas)
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 4
|
||||
|
||||
docsearch.delete_by_document_id(ids[0])
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 3
|
||||
|
||||
docsearch.delete(ids[1:3])
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 1
|
||||
|
||||
docsearch.delete(["not-existing"])
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 1
|
||||
|
||||
docsearch.clear()
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_cassandra()
|
||||
# test_cassandra_with_score()
|
||||
# test_cassandra_max_marginal_relevance_search()
|
||||
# test_cassandra_add_extra()
|
||||
# test_cassandra_no_drop()
|
||||
# test_cassandra_delete()
|
||||
|
Loading…
Reference in New Issue
Block a user