community: add support for callable filters in FAISS (#16190)

- **Description:**
Filtering in a FAISS vectorstores is very inflexible and doesn't allow
that many use case. I think supporting callable like this enables a lot:
regular expressions, condition on multiple keys etc. **Note** I had to
manually alter a test. I don't understand if it was falty to begin with
or if there is something funky going on.
- **Issue:** None
- **Dependencies:** None
- **Twitter handle:** None

Signed-off-by: thiswillbeyourgithub <26625900+thiswillbeyourgithub@users.noreply.github.com>
This commit is contained in:
thiswillbeyourgithub
2024-01-30 05:05:56 +01:00
committed by GitHub
parent 1703fe2361
commit 1d082359ee
3 changed files with 116 additions and 37 deletions

View File

@@ -273,7 +273,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -282,7 +282,9 @@ class FAISS(VectorStore):
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
filter (Optional[Union[Callable, Dict[str, Any]]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
**kwargs: kwargs to be passed to similarity search. Can include:
@@ -299,6 +301,27 @@ class FAISS(VectorStore):
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
docs = []
if filter is not None:
if isinstance(filter, dict):
def filter_func(metadata):
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
for j, i in enumerate(indices[0]):
if i == -1:
# This happens when not enough docs are returned.
@@ -307,13 +330,8 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if filter is not None:
filter = {
key: [value] if not isinstance(value, list) else value
for key, value in filter.items()
}
if all(doc.metadata.get(key) in value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
if filter is not None and filter_func(doc.metadata):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
@@ -336,7 +354,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -345,7 +363,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, Any]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
**kwargs: kwargs to be passed to similarity search. Can include:
@@ -372,7 +393,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -381,7 +402,10 @@ class FAISS(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@@ -403,7 +427,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -412,7 +436,10 @@ class FAISS(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@@ -443,7 +470,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@@ -463,7 +493,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@@ -472,7 +502,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@@ -492,7 +525,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@@ -517,7 +550,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@@ -545,7 +578,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance.
@@ -572,6 +605,24 @@ class FAISS(VectorStore):
)
if filter is not None:
filtered_indices = []
if isinstance(filter, dict):
def filter_func(metadata):
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
@@ -580,12 +631,7 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(
doc.metadata.get(key) in value
if isinstance(value, list)
else doc.metadata.get(key) == value
for key, value in filter.items()
):
if filter_func(doc.metadata):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
@@ -617,7 +663,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance asynchronously.
@@ -655,7 +701,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@@ -686,7 +732,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance asynchronously.
@@ -719,7 +765,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@@ -756,7 +802,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance asynchronously.
@@ -1110,7 +1156,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -1139,7 +1185,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]: