Mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-16 20:42:01 +00:00
Cassandra Vector Store, add metadata filtering + improvements (#9280)
This PR addresses a few minor issues with the Cassandra vector store implementation and extends the store to support metadata search.

Thanks to the latest cassIO library (>=0.1.0), metadata filtering is available in the store. Further:

- the "relevance" score is prevented from being flipped in the [0,1] interval, thus ensuring that 1 corresponds to the closest vector (this is related to how the underlying cassIO class returns the cosine difference);
- bumped the cassIO package version both in the notebooks and the pyproject.toml;
- adjusted the textfile location for the vector-store example after the reshuffling of the Langchain repo dir structure;
- added demonstration of metadata filtering in the Cassandra vector store notebook;
- better docstring for the Cassandra vector store class;
- fixed test flakiness and removed offending out-of-place escape chars from a test module docstring.

To my knowledge all relevant tests pass and mypy+black+ruff don't complain. (mypy gives unrelated errors in other modules, which clearly don't depend on the content of this PR.)

Thank you!
Stefano

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
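For readers skimming the diff, here is a minimal usage sketch of what the change enables, restricted to calls that actually appear in the diff below. The cluster connection, keyspace, table name, query string and source-file value are placeholders; a vector-capable Cassandra / Astra DB instance, `cassio>=0.1.0` and an OpenAI API key are assumed:

```python
from cassandra.cluster import Cluster
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

session = Cluster().connect()  # placeholder: connect to your own cluster / Astra DB
embeddings = OpenAIEmbeddings()
vectorstore = Cassandra(embeddings, session, "my_keyspace", "my_vector_store")

# New in this PR: restrict a search to rows whose metadata matches the filter.
docs = vectorstore.similarity_search(
    "What did the president say about Ketanji Brown Jackson?",
    k=4,
    filter={"source": "state_of_the_union.txt"},
)

# Also new: relevance scores are no longer flipped, so 1.0 means "closest vector".
docs_and_scores = vectorstore.similarity_search_with_relevance_scores(
    "What did the president say about Ketanji Brown Jackson?", k=4
)
```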
Parent: 49694f6a3f
Commit: 415d38ae62
@@ -23,7 +23,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install \"cassio>=0.0.7\""
+    "!pip install \"cassio>=0.1.0\""
    ]
   },
   {
@@ -155,7 +155,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.10.12"
   }
  },
  "nbformat": 4,

@@ -23,7 +23,7 @@
   },
   "outputs": [],
   "source": [
-    "!pip install \"cassio>=0.0.7\""
+    "!pip install \"cassio>=0.1.0\""
   ]
  },
  {
@@ -152,7 +152,9 @@
   "source": [
    "from langchain.document_loaders import TextLoader\n",
    "\n",
-    "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
+    "SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n",
+    "\n",
+    "loader = TextLoader(SOURCE_FILE_NAME)\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
@@ -197,7 +199,7 @@
    "# table_name=table_name,\n",
    "# )\n",
    "\n",
-    "# docsearch_preexisting.similarity_search(query, k=2)"
+    "# docs = docsearch_preexisting.similarity_search(query, k=2)"
   ]
  },
  {
@@ -253,6 +255,51 @@
    "for i, doc in enumerate(found_docs):\n",
    "    print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da791c5f",
+   "metadata": {},
+   "source": [
+    "### Metadata filtering\n",
+    "\n",
+    "You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n",
+    "\n",
+    "Since only one files was inserted, this is just a demonstration of how filters are passed:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "93f132fa",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "filter = {\"source\": SOURCE_FILE_NAME}\n",
+    "filtered_docs = docsearch.similarity_search(query, filter=filter, k=5)\n",
+    "print(f\"{len(filtered_docs)} documents retrieved.\")\n",
+    "print(f\"{filtered_docs[0].page_content[:64]} ...\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b413ec4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "filter = {\"source\": \"nonexisting_file.txt\"}\n",
+    "filtered_docs2 = docsearch.similarity_search(query, filter=filter)\n",
+    "print(f\"{len(filtered_docs2)} documents retrieved.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a0fea764",
+   "metadata": {},
+   "source": [
+    "Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with Langchain."
+   ]
   }
  ],
  "metadata": {
@@ -271,7 +318,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.10.12"
  }
 },
 "nbformat": 4,

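The markdown cell added above notes that metadata can be customized at insertion time, while the new code cells only filter on the default `"source"` key. A small illustrative sketch of that customization (the texts and metadata values are invented, and `docsearch` is assumed to be the Cassandra store built earlier in the notebook):

```python
# Attach arbitrary string key/value metadata while inserting texts.
texts = ["The cat sat on the mat.", "Cassandra tables scale horizontally."]
metadatas = [
    {"source": "pets.txt", "topic": "cats"},
    {"source": "databases.txt", "topic": "cassandra"},
]
docsearch.add_texts(texts, metadatas=metadatas)

# Any of the attached keys can then be used as a filter.
hits = docsearch.similarity_search(
    "how does the database scale?",
    k=2,
    filter={"topic": "cassandra"},
)
```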
@@ -2,7 +2,18 @@ from __future__ import annotations
 
 import typing
 import uuid
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 
@@ -18,11 +29,12 @@ CVST = TypeVar("CVST", bound="Cassandra")
 
 
 class Cassandra(VectorStore):
-    """`Cassandra` vector store.
+    """Wrapper around Apache Cassandra(R) for vector-store workloads.
 
-    It based on the Cassandra vector-store capabilities, based on cassIO.
-    There is no notion of a default table name, since each embedding
-    function implies its own vector dimension, which is part of the schema.
+    To use it, you need a recent installation of the `cassio` library
+    and a Cassandra cluster / Astra DB instance supporting vector capabilities.
+
+    Visit the cassio.org website for extensive quickstarts and code examples.
 
     Example:
         .. code-block:: python
@@ -31,12 +43,20 @@ class Cassandra(VectorStore):
                 from langchain.embeddings.openai import OpenAIEmbeddings
 
                 embeddings = OpenAIEmbeddings()
-                session = ...
-                keyspace = 'my_keyspace'
-                vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
+                session = ...             # create your Cassandra session object
+                keyspace = 'my_keyspace'  # the keyspace should exist already
+                table_name = 'my_vector_store'
+                vectorstore = Cassandra(embeddings, session, keyspace, table_name)
     """
 
-    _embedding_dimension: int | None
+    _embedding_dimension: Union[int, None]
+
+    @staticmethod
+    def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
+        if filter_dict is None:
+            return {}
+        else:
+            return filter_dict
 
     def _get_embedding_dimension(self) -> int:
         if self._embedding_dimension is None:
@@ -81,8 +101,18 @@ class Cassandra(VectorStore):
     def embeddings(self) -> Embeddings:
         return self.embedding
 
+    @staticmethod
+    def _dont_flip_the_cos_score(distance: float) -> float:
+        # the identity
+        return distance
+
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        return self._cosine_relevance_score_fn
+        """
+        The underlying VectorTable already returns a "score proper",
+        i.e. one in [0, 1] where higher means more *similar*,
+        so here the final score transformation is not reversing the interval:
+        """
+        return self._dont_flip_the_cos_score
 
     def delete_collection(self) -> None:
         """
@@ -172,22 +202,24 @@ class Cassandra(VectorStore):
         self,
         embedding: List[float],
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
     ) -> List[Tuple[Document, float, str]]:
         """Return docs most similar to embedding vector.
 
-        No support for `filter` query (on metadata) along with vector search.
-
         Args:
             embedding (str): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
         Returns:
             List of (Document, score, id), the most similar to the query vector.
         """
+        search_metadata = self._filter_to_metadata(filter)
+        #
         hits = self.table.search(
             embedding_vector=embedding,
             top_k=k,
             metric="cos",
             metric_threshold=None,
+            metadata=search_metadata,
         )
         # We stick to 'cos' distance as it can be normalized on a 0-1 axis
         # (1=most relevant), as required by this class' contract.
@@ -207,11 +239,13 @@ class Cassandra(VectorStore):
         self,
         query: str,
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
     ) -> List[Tuple[Document, float, str]]:
         embedding_vector = self.embedding.embed_query(query)
         return self.similarity_search_with_score_id_by_vector(
             embedding=embedding_vector,
             k=k,
+            filter=filter,
         )
 
     # id-unaware search facilities
@@ -219,11 +253,10 @@ class Cassandra(VectorStore):
         self,
         embedding: List[float],
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
     ) -> List[Tuple[Document, float]]:
         """Return docs most similar to embedding vector.
 
-        No support for `filter` query (on metadata) along with vector search.
-
         Args:
             embedding (str): Embedding to look up documents similar to.
             k (int): Number of Documents to return. Defaults to 4.
@@ -235,6 +268,7 @@ class Cassandra(VectorStore):
             for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
                 embedding=embedding,
                 k=k,
+                filter=filter,
             )
         ]
 
@@ -242,18 +276,21 @@ class Cassandra(VectorStore):
         self,
         query: str,
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         embedding_vector = self.embedding.embed_query(query)
         return self.similarity_search_by_vector(
             embedding_vector,
             k,
+            filter=filter,
         )
 
     def similarity_search_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         return [
@@ -261,6 +298,7 @@ class Cassandra(VectorStore):
             for doc, _ in self.similarity_search_with_score_by_vector(
                 embedding,
                 k,
+                filter=filter,
             )
         ]
 
@@ -268,11 +306,13 @@ class Cassandra(VectorStore):
         self,
         query: str,
         k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
     ) -> List[Tuple[Document, float]]:
         embedding_vector = self.embedding.embed_query(query)
         return self.similarity_search_with_score_by_vector(
             embedding_vector,
             k,
+            filter=filter,
         )
 
     def max_marginal_relevance_search_by_vector(
@@ -281,6 +321,7 @@ class Cassandra(VectorStore):
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -296,11 +337,14 @@ class Cassandra(VectorStore):
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
+        search_metadata = self._filter_to_metadata(filter)
+
         prefetchHits = self.table.search(
             embedding_vector=embedding,
             top_k=fetch_k,
             metric="cos",
             metric_threshold=None,
+            metadata=search_metadata,
         )
         # let the mmr utility pick the *indices* in the above array
         mmrChosenIndices = maximal_marginal_relevance(
@@ -328,6 +372,7 @@ class Cassandra(VectorStore):
         k: int = 4,
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance.
@@ -350,6 +395,7 @@ class Cassandra(VectorStore):
             k,
             fetch_k,
             lambda_mult=lambda_mult,
+            filter=filter,
         )
 
     @classmethod

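A note on the relevance-score change above: LangChain's stock cosine conversion (`_cosine_relevance_score_fn`) assumes the backend returns a distance and maps it to a relevance of roughly `1.0 - distance`, whereas cassIO's `VectorTable` already returns a similarity in [0, 1] with 1 meaning closest; hence the identity `_dont_flip_the_cos_score`. A toy illustration with an invented score:

```python
raw_score_from_cassio = 0.93  # already "higher = more similar"

flipped = 1.0 - raw_score_from_cassio  # 0.07: what the old conversion would yield
kept = raw_score_from_cassio           # 0.93: what the store now reports

assert kept > flipped  # ranking preserved, 1.0 still means "closest vector"
```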
@@ -1,4 +1,5 @@
 """Test Cassandra functionality."""
+import time
 from typing import List, Optional, Type
 
 from cassandra.cluster import Cluster
@@ -61,9 +62,9 @@ def test_cassandra_with_score() -> None:
     docs = [o[0] for o in output]
     scores = [o[1] for o in output]
     assert docs == [
-        Document(page_content="foo", metadata={"page": 0}),
-        Document(page_content="bar", metadata={"page": 1}),
-        Document(page_content="baz", metadata={"page": 2}),
+        Document(page_content="foo", metadata={"page": "0.0"}),
+        Document(page_content="bar", metadata={"page": "1.0"}),
+        Document(page_content="baz", metadata={"page": "2.0"}),
     ]
     assert scores[0] > scores[1] > scores[2]
 
@@ -76,10 +77,10 @@ def test_cassandra_max_marginal_relevance_search() -> None:
 
         ______ v2
        /      \
-      /        \  v1
+      /        |  v1
     v3 |     .    | query
-      \        /  v0
-       \______/                 (N.B. very crude drawing)
+      |        /  v0
+      |______/                  (N.B. very crude drawing)
 
     With fetch_k==3 and k==2, when query is at (1, ),
     one expects that v2 and v0 are returned (in some order).
@@ -94,8 +95,8 @@ def test_cassandra_max_marginal_relevance_search() -> None:
         (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
     }
     assert output_set == {
-        ("+0.25", 2),
-        ("-0.124", 0),
+        ("+0.25", "2.0"),
+        ("-0.124", "0.0"),
     }
 
 
@@ -150,6 +151,7 @@ def test_cassandra_delete() -> None:
     assert len(output) == 1
 
     docsearch.clear()
+    time.sleep(0.3)
     output = docsearch.similarity_search("foo", k=10)
     assert len(output) == 0
 