Vector store support for Cassandra (#6426)

This addresses #6291 adding support for using Cassandra (and compatible
databases, such as DataStax Astra DB) as a [Vector
Store](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes).

A new class `Cassandra` is introduced, which complies with the contract
and interface for a vector store, along with the corresponding
integration test, a sample notebook and modified dependency toml.

Dependencies: the implementation relies on the library `cassio`, which
simplifies interacting with Cassandra for ML- and LLM-oriented
workloads. CassIO, in turn, uses the `cassandra-driver` low-lever
drivers to communicate with the database. The former is added as
optional dependency (+ in `extended_testing`), the latter was already in
the project.

Integration testing relies on a locally-running instance of Cassandra.
[Here](https://cassio.org/more_info/#use-a-local-vector-capable-cassandra)
a detailed description can be found on how to compile and run it (at the
time of writing the feature has not made it yet to a release).

During development of the integration tests, I added a new "fake
embedding" class for what I consider a more controlled way of testing
the MMR search method. Likewise, I had to amend what looked like a
glitch in the behaviour of `ConsistentFakeEmbeddings` whereby an
`embed_query` call would have bypassed storage of the requested text in
the class cache for use in later repeated invocations.

@dev2049 might be the right person to tag here for a review. Thank you!

---------

Co-authored-by: rlm <pexpresss31@gmail.com>
This commit is contained in:
Stefano Lottini
2023-06-20 19:46:20 +02:00
committed by GitHub
parent cac6e45a67
commit 22af93d851
5 changed files with 835 additions and 0 deletions

View File

@@ -0,0 +1,269 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "683953b3",
"metadata": {},
"source": [
"# Cassandra\n",
"\n",
">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n",
"\n",
"Newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n",
"\n",
"To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4c41cad-08ef-4f72-a545-2151e4598efe",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.5\""
]
},
{
"cell_type": "markdown",
"id": "b7e46bb0",
"metadata": {},
"source": [
"### Please provide database connection parameters and secrets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36128a32",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
"\n",
"keyspace_name = input('\\nKeyspace name? ')\n",
"\n",
"if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
]
},
{
"cell_type": "markdown",
"id": "4f22aac2",
"metadata": {},
"source": [
"#### depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "677f8576",
"metadata": {},
"outputs": [],
"source": [
"from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n",
"\n",
"if database_mode == 'L':\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n",
"elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n",
" cluster = Cluster(\n",
" cloud={\n",
" \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n",
" },\n",
" auth_provider=PlainTextAuthProvider(\n",
" ASTRA_DB_CLIENT_ID,\n",
" ASTRA_DB_APPLICATION_TOKEN,\n",
" ),\n",
" )\n",
" session = cluster.connect()\n",
"else:\n",
" raise NotImplementedError"
]
},
{
"cell_type": "markdown",
"id": "320af802-9271-46ee-948f-d2453933d44b",
"metadata": {},
"source": [
"### Please provide OpenAI access key\n",
"\n",
"We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7",
"metadata": {},
"outputs": [],
"source": [
"os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')"
]
},
{
"cell_type": "markdown",
"id": "e98a139b",
"metadata": {},
"source": [
"### Creation and usage of the Vector Store"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aac9563e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Cassandra\n",
"from langchain.document_loaders import TextLoader"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3c3999a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader('../../../state_of_the_union.txt')\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embedding_function = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e104aee",
"metadata": {},
"outputs": [],
"source": [
"table_name = 'my_vector_db_table'\n",
"\n",
"docsearch = Cassandra.from_documents(\n",
" documents=docs,\n",
" embedding=embedding_function,\n",
" session=session,\n",
" keyspace=keyspace_name,\n",
" table_name=table_name,\n",
")\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f509ee02",
"metadata": {},
"outputs": [],
"source": [
"## if you already have an index, you can load it and use it like this:\n",
"\n",
"# docsearch_preexisting = Cassandra(\n",
"# embedding=embedding_function,\n",
"# session=session,\n",
"# keyspace=keyspace_name,\n",
"# table_name=table_name,\n",
"# )\n",
"\n",
"# docsearch_preexisting.similarity_search(query, k=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c608226",
"metadata": {},
"outputs": [],
"source": [
"print(docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "d46d1452",
"metadata": {},
"source": [
"### Maximal Marginal Relevance Searches\n",
"\n",
"In addition to using similarity search in the retriever object, you can also use `mmr` as retriever.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a359ed74",
"metadata": {},
"outputs": [],
"source": [
"retriever = docsearch.as_retriever(search_type=\"mmr\")\n",
"matched_docs = retriever.get_relevant_documents(query)\n",
"for i, d in enumerate(matched_docs):\n",
" print(f\"\\n## Document {i}\\n\")\n",
" print(d.page_content)"
]
},
{
"cell_type": "markdown",
"id": "7c477287",
"metadata": {},
"source": [
"Or use `max_marginal_relevance_search` directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ca82740",
"metadata": {},
"outputs": [],
"source": [
"found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n",
"for i, doc in enumerate(found_docs):\n",
" print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}