mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
community: add support for callable filters in FAISS (#16190)
- **Description:** Filtering in a FAISS vectorstores is very inflexible and doesn't allow that many use case. I think supporting callable like this enables a lot: regular expressions, condition on multiple keys etc. **Note** I had to manually alter a test. I don't understand if it was falty to begin with or if there is something funky going on. - **Issue:** None - **Dependencies:** None - **Twitter handle:** None Signed-off-by: thiswillbeyourgithub <26625900+thiswillbeyourgithub@users.noreply.github.com>
This commit is contained in:
parent
1703fe2361
commit
1d082359ee
@ -416,7 +416,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Similarity Search with filtering\n",
|
||||
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. You can filter the documents based on metadata. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
|
||||
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. This filter is either a callble that takes as input a metadata dict and returns a bool, or a metadata dict where each missing key is ignored and each present k must be in a list of values. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -480,6 +480,8 @@
|
||||
],
|
||||
"source": [
|
||||
"results_with_scores = db.similarity_search_with_score(\"foo\", filter=dict(page=1))\n",
|
||||
"# Or with a callable:\n",
|
||||
"# results_with_scores = db.similarity_search_with_score(\"foo\", filter=lambda d: d[\"page\"] == 1)\n",
|
||||
"for doc, score in results_with_scores:\n",
|
||||
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
|
||||
]
|
||||
|
@ -273,7 +273,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -282,7 +282,9 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Union[Callable, Dict[str, Any]]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
**kwargs: kwargs to be passed to similarity search. Can include:
|
||||
@ -299,6 +301,27 @@ class FAISS(VectorStore):
|
||||
faiss.normalize_L2(vector)
|
||||
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
||||
docs = []
|
||||
|
||||
if filter is not None:
|
||||
if isinstance(filter, dict):
|
||||
|
||||
def filter_func(metadata):
|
||||
if all(
|
||||
metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
return True
|
||||
return False
|
||||
elif callable(filter):
|
||||
filter_func = filter
|
||||
else:
|
||||
raise ValueError(
|
||||
"filter must be a dict of metadata or "
|
||||
f"a callable, not {type(filter)}"
|
||||
)
|
||||
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
@ -307,13 +330,8 @@ class FAISS(VectorStore):
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if filter is not None:
|
||||
filter = {
|
||||
key: [value] if not isinstance(value, list) else value
|
||||
for key, value in filter.items()
|
||||
}
|
||||
if all(doc.metadata.get(key) in value for key, value in filter.items()):
|
||||
docs.append((doc, scores[0][j]))
|
||||
if filter is not None and filter_func(doc.metadata):
|
||||
docs.append((doc, scores[0][j]))
|
||||
else:
|
||||
docs.append((doc, scores[0][j]))
|
||||
|
||||
@ -336,7 +354,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -345,7 +363,10 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
**kwargs: kwargs to be passed to similarity search. Can include:
|
||||
@ -372,7 +393,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -381,7 +402,10 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
@ -403,7 +427,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -412,7 +436,10 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
@ -443,7 +470,10 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
@ -463,7 +493,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
@ -472,7 +502,10 @@ class FAISS(VectorStore):
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata.
|
||||
Defaults to None. If a callable, it must take as input the
|
||||
metadata dict of Document and return a bool.
|
||||
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
@ -492,7 +525,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
@ -517,7 +550,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
@ -545,7 +578,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores selected using the maximal marginal
|
||||
relevance.
|
||||
@ -572,6 +605,24 @@ class FAISS(VectorStore):
|
||||
)
|
||||
if filter is not None:
|
||||
filtered_indices = []
|
||||
if isinstance(filter, dict):
|
||||
|
||||
def filter_func(metadata):
|
||||
if all(
|
||||
metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
return True
|
||||
return False
|
||||
elif callable(filter):
|
||||
filter_func = filter
|
||||
else:
|
||||
raise ValueError(
|
||||
"filter must be a dict of metadata or "
|
||||
f"a callable, not {type(filter)}"
|
||||
)
|
||||
for i in indices[0]:
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
@ -580,12 +631,7 @@ class FAISS(VectorStore):
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if all(
|
||||
doc.metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else doc.metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
if filter_func(doc.metadata):
|
||||
filtered_indices.append(i)
|
||||
indices = np.array([filtered_indices])
|
||||
# -1 happens when not enough docs are returned.
|
||||
@ -617,7 +663,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores selected using the maximal marginal
|
||||
relevance asynchronously.
|
||||
@ -655,7 +701,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -686,7 +732,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance asynchronously.
|
||||
@ -719,7 +765,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
@ -756,7 +802,7 @@ class FAISS(VectorStore):
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance asynchronously.
|
||||
@ -1110,7 +1156,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -1139,7 +1185,7 @@ class FAISS(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
|
@ -307,6 +307,9 @@ def test_faiss_mmr_with_metadatas_and_filter() -> None:
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
|
||||
assert output[0][1] == 0.0
|
||||
assert output == docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -321,6 +324,12 @@ async def test_faiss_async_mmr_with_metadatas_and_filter() -> None:
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
|
||||
assert output[0][1] == 0.0
|
||||
assert (
|
||||
output
|
||||
== await docsearch.amax_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -336,6 +345,9 @@ def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
assert output == docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -351,6 +363,11 @@ async def test_faiss_async_mmr_with_metadatas_and_list_filter() -> None:
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
assert output == (
|
||||
await docsearch.amax_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -421,7 +438,11 @@ def test_faiss_with_metadatas_and_filter() -> None:
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
|
||||
assert output == [Document(page_content="bar", metadata={"page": 1})]
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
assert output != [Document(page_content="bar", metadata={"page": 1})]
|
||||
assert output == docsearch.similarity_search(
|
||||
"foo", k=1, filter=lambda di: di["page"] == 1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -444,7 +465,11 @@ async def test_faiss_async_with_metadatas_and_filter() -> None:
|
||||
)
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = await docsearch.asimilarity_search("foo", k=1, filter={"page": 1})
|
||||
assert output == [Document(page_content="bar", metadata={"page": 1})]
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
assert output != [Document(page_content="bar", metadata={"page": 1})]
|
||||
assert output == await docsearch.asimilarity_search(
|
||||
"foo", k=1, filter=lambda di: di["page"] == 1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -474,6 +499,9 @@ def test_faiss_with_metadatas_and_list_filter() -> None:
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
assert output == docsearch.similarity_search(
|
||||
"foor", k=1, filter=lambda di: di["page"] in [0, 1, 2]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
@ -503,6 +531,9 @@ async def test_faiss_async_with_metadatas_and_list_filter() -> None:
|
||||
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
|
||||
output = await docsearch.asimilarity_search("foor", k=1, filter={"page": [0, 1, 2]})
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
assert output == await docsearch.asimilarity_search(
|
||||
"foor", k=1, filter=lambda di: di["page"] in [0, 1, 2]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("faiss")
|
||||
|
Loading…
Reference in New Issue
Block a user