mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-16 04:21:52 +00:00
Cassandra Vector Store, add metadata filtering + improvements (#9280)
This PR addresses a few minor issues with the Cassandra vector store implementation and extends the store to support Metadata search. Thanks to the latest cassIO library (>=0.1.0), metadata filtering is available in the store. Further, - the "relevance" score is prevented from being flipped in the [0,1] interval, thus ensuring that 1 corresponds to the closest vector (this is related to how the underlying cassIO class returns the cosine difference); - bumped the cassIO package version both in the notebooks and the pyproject.toml; - adjusted the textfile location for the vector-store example after the reshuffling of the Langchain repo dir structure; - added demonstration of metadata filtering in the Cassandra vector store notebook; - better docstring for the Cassandra vector store class; - fixed test flakiness and removed offending out-of-place escape chars from a test module docstring; To my knowledge all relevant tests pass and mypy+black+ruff don't complain. (mypy gives unrelated errors in other modules, which clearly don't depend on the content of this PR). Thank you! Stefano --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
49694f6a3f
commit
415d38ae62
@ -23,7 +23,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install \"cassio>=0.0.7\""
|
||||
"!pip install \"cassio>=0.1.0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -155,7 +155,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -23,7 +23,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install \"cassio>=0.0.7\""
|
||||
"!pip install \"cassio>=0.1.0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -152,7 +152,9 @@
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||
"SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n",
|
||||
"\n",
|
||||
"loader = TextLoader(SOURCE_FILE_NAME)\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
@ -197,7 +199,7 @@
|
||||
"# table_name=table_name,\n",
|
||||
"# )\n",
|
||||
"\n",
|
||||
"# docsearch_preexisting.similarity_search(query, k=2)"
|
||||
"# docs = docsearch_preexisting.similarity_search(query, k=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -253,6 +255,51 @@
|
||||
"for i, doc in enumerate(found_docs):\n",
|
||||
" print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da791c5f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Metadata filtering\n",
|
||||
"\n",
|
||||
"You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n",
|
||||
"\n",
|
||||
"Since only one files was inserted, this is just a demonstration of how filters are passed:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "93f132fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"filter = {\"source\": SOURCE_FILE_NAME}\n",
|
||||
"filtered_docs = docsearch.similarity_search(query, filter=filter, k=5)\n",
|
||||
"print(f\"{len(filtered_docs)} documents retrieved.\")\n",
|
||||
"print(f\"{filtered_docs[0].page_content[:64]} ...\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b413ec4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"filter = {\"source\": \"nonexisting_file.txt\"}\n",
|
||||
"filtered_docs2 = docsearch.similarity_search(query, filter=filter)\n",
|
||||
"print(f\"{len(filtered_docs2)} documents retrieved.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0fea764",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with Langchain."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -271,7 +318,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -2,7 +2,18 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -18,11 +29,12 @@ CVST = TypeVar("CVST", bound="Cassandra")
|
||||
|
||||
|
||||
class Cassandra(VectorStore):
|
||||
"""`Cassandra` vector store.
|
||||
"""Wrapper around Apache Cassandra(R) for vector-store workloads.
|
||||
|
||||
It based on the Cassandra vector-store capabilities, based on cassIO.
|
||||
There is no notion of a default table name, since each embedding
|
||||
function implies its own vector dimension, which is part of the schema.
|
||||
To use it, you need a recent installation of the `cassio` library
|
||||
and a Cassandra cluster / Astra DB instance supporting vector capabilities.
|
||||
|
||||
Visit the cassio.org website for extensive quickstarts and code examples.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@ -31,12 +43,20 @@ class Cassandra(VectorStore):
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
session = ...
|
||||
keyspace = 'my_keyspace'
|
||||
vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
|
||||
session = ... # create your Cassandra session object
|
||||
keyspace = 'my_keyspace' # the keyspace should exist already
|
||||
table_name = 'my_vector_store'
|
||||
vectorstore = Cassandra(embeddings, session, keyspace, table_name)
|
||||
"""
|
||||
|
||||
_embedding_dimension: int | None
|
||||
_embedding_dimension: Union[int, None]
|
||||
|
||||
@staticmethod
|
||||
def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
|
||||
if filter_dict is None:
|
||||
return {}
|
||||
else:
|
||||
return filter_dict
|
||||
|
||||
def _get_embedding_dimension(self) -> int:
|
||||
if self._embedding_dimension is None:
|
||||
@ -81,8 +101,18 @@ class Cassandra(VectorStore):
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding
|
||||
|
||||
@staticmethod
|
||||
def _dont_flip_the_cos_score(distance: float) -> float:
|
||||
# the identity
|
||||
return distance
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._cosine_relevance_score_fn
|
||||
"""
|
||||
The underlying VectorTable already returns a "score proper",
|
||||
i.e. one in [0, 1] where higher means more *similar*,
|
||||
so here the final score transformation is not reversing the interval:
|
||||
"""
|
||||
return self._dont_flip_the_cos_score
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
"""
|
||||
@ -172,22 +202,24 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Tuple[Document, float, str]]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
No support for `filter` query (on metadata) along with vector search.
|
||||
|
||||
Args:
|
||||
embedding (str): Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
Returns:
|
||||
List of (Document, score, id), the most similar to the query vector.
|
||||
"""
|
||||
search_metadata = self._filter_to_metadata(filter)
|
||||
#
|
||||
hits = self.table.search(
|
||||
embedding_vector=embedding,
|
||||
top_k=k,
|
||||
metric="cos",
|
||||
metric_threshold=None,
|
||||
metadata=search_metadata,
|
||||
)
|
||||
# We stick to 'cos' distance as it can be normalized on a 0-1 axis
|
||||
# (1=most relevant), as required by this class' contract.
|
||||
@ -207,11 +239,13 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Tuple[Document, float, str]]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_with_score_id_by_vector(
|
||||
embedding=embedding_vector,
|
||||
k=k,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
# id-unaware search facilities
|
||||
@ -219,11 +253,10 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
No support for `filter` query (on metadata) along with vector search.
|
||||
|
||||
Args:
|
||||
embedding (str): Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
@ -235,6 +268,7 @@ class Cassandra(VectorStore):
|
||||
for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
)
|
||||
]
|
||||
|
||||
@ -242,18 +276,21 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_by_vector(
|
||||
embedding_vector,
|
||||
k,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return [
|
||||
@ -261,6 +298,7 @@ class Cassandra(VectorStore):
|
||||
for doc, _ in self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
filter=filter,
|
||||
)
|
||||
]
|
||||
|
||||
@ -268,11 +306,13 @@ class Cassandra(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embedding_vector = self.embedding.embed_query(query)
|
||||
return self.similarity_search_with_score_by_vector(
|
||||
embedding_vector,
|
||||
k,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
@ -281,6 +321,7 @@ class Cassandra(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -296,11 +337,14 @@ class Cassandra(VectorStore):
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
search_metadata = self._filter_to_metadata(filter)
|
||||
|
||||
prefetchHits = self.table.search(
|
||||
embedding_vector=embedding,
|
||||
top_k=fetch_k,
|
||||
metric="cos",
|
||||
metric_threshold=None,
|
||||
metadata=search_metadata,
|
||||
)
|
||||
# let the mmr utility pick the *indices* in the above array
|
||||
mmrChosenIndices = maximal_marginal_relevance(
|
||||
@ -328,6 +372,7 @@ class Cassandra(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -350,6 +395,7 @@ class Cassandra(VectorStore):
|
||||
k,
|
||||
fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test Cassandra functionality."""
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
@ -61,9 +62,9 @@ def test_cassandra_with_score() -> None:
|
||||
docs = [o[0] for o in output]
|
||||
scores = [o[1] for o in output]
|
||||
assert docs == [
|
||||
Document(page_content="foo", metadata={"page": 0}),
|
||||
Document(page_content="bar", metadata={"page": 1}),
|
||||
Document(page_content="baz", metadata={"page": 2}),
|
||||
Document(page_content="foo", metadata={"page": "0.0"}),
|
||||
Document(page_content="bar", metadata={"page": "1.0"}),
|
||||
Document(page_content="baz", metadata={"page": "2.0"}),
|
||||
]
|
||||
assert scores[0] > scores[1] > scores[2]
|
||||
|
||||
@ -76,10 +77,10 @@ def test_cassandra_max_marginal_relevance_search() -> None:
|
||||
|
||||
______ v2
|
||||
/ \
|
||||
/ \ v1
|
||||
/ | v1
|
||||
v3 | . | query
|
||||
\ / v0
|
||||
\______/ (N.B. very crude drawing)
|
||||
| / v0
|
||||
|______/ (N.B. very crude drawing)
|
||||
|
||||
With fetch_k==3 and k==2, when query is at (1, ),
|
||||
one expects that v2 and v0 are returned (in some order).
|
||||
@ -94,8 +95,8 @@ def test_cassandra_max_marginal_relevance_search() -> None:
|
||||
(mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
|
||||
}
|
||||
assert output_set == {
|
||||
("+0.25", 2),
|
||||
("-0.124", 0),
|
||||
("+0.25", "2.0"),
|
||||
("-0.124", "0.0"),
|
||||
}
|
||||
|
||||
|
||||
@ -150,6 +151,7 @@ def test_cassandra_delete() -> None:
|
||||
assert len(output) == 1
|
||||
|
||||
docsearch.clear()
|
||||
time.sleep(0.3)
|
||||
output = docsearch.similarity_search("foo", k=10)
|
||||
assert len(output) == 0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user