community[patch]: refactor code for faiss vectorstore, update faiss vectorstore documentation (#18092)

**Description:** Refactor code of FAISS vectorstore and update the
related documentation.
Details: 
 - replace `.format()` with f-strings for string formatting;
- refactor definition of a filtering function to make code more readable
and more flexible;
- slightly improve efficiency of
`max_marginal_relevance_search_with_score_by_vector` method by removing
unnecessary looping over the same elements;
- slightly improve efficiency of `delete` method by using set data
structure for checking if the element was already deleted;

**Issue:** fix small inconsistency in the documentation (the old example
was incorrect and inapplicable to faiss vectorstore)

**Dependencies:** basic langchain-community dependencies and `faiss`
(for CPU or for GPU)

**Twitter handle:** antonenkodev
This commit is contained in:
Tymofii
2024-03-12 05:33:03 +00:00
committed by GitHub
parent acf1ecc081
commit 0bec1f6877
2 changed files with 91 additions and 121 deletions

View File

@@ -119,9 +119,8 @@ class FAISS(VectorStore):
and self._normalize_L2
):
warnings.warn(
"Normalizing L2 is not applicable for metric type: {strategy}".format(
strategy=self.distance_strategy
)
"Normalizing L2 is not applicable for "
f"metric type: {self.distance_strategy}"
)
@property
@@ -306,24 +305,7 @@ class FAISS(VectorStore):
docs = []
if filter is not None:
if isinstance(filter, dict):
def filter_func(metadata): # type: ignore[no-untyped-def]
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
filter_func = self._create_filter_func(filter)
for j, i in enumerate(indices[0]):
if i == -1:
@@ -608,25 +590,8 @@ class FAISS(VectorStore):
fetch_k if filter is None else fetch_k * 2,
)
if filter is not None:
filter_func = self._create_filter_func(filter)
filtered_indices = []
if isinstance(filter, dict):
def filter_func(metadata): # type: ignore[no-untyped-def]
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
@@ -646,18 +611,18 @@ class FAISS(VectorStore):
k=k,
lambda_mult=lambda_mult,
)
selected_indices = [indices[0][i] for i in mmr_selected]
selected_scores = [scores[0][i] for i in mmr_selected]
docs_and_scores = []
for i, score in zip(selected_indices, selected_scores):
if i == -1:
for i in mmr_selected:
if indices[0][i] == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
_id = self.index_to_docstore_id[indices[0][i]]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs_and_scores.append((doc, score))
docs_and_scores.append((doc, scores[0][i]))
return docs_and_scores
async def amax_marginal_relevance_search_with_score_by_vector(
@@ -857,9 +822,9 @@ class FAISS(VectorStore):
)
reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
index_to_delete = [reversed_index[id_] for id_ in ids]
index_to_delete = {reversed_index[id_] for id_ in ids}
self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
self.index.remove_ids(np.fromiter(index_to_delete, dtype=np.int64))
self.docstore.delete(ids)
remaining_ids = [
@@ -1079,12 +1044,10 @@ class FAISS(VectorStore):
# save index separately since it is not picklable
faiss = dependable_faiss_import()
faiss.write_index(
self.index, str(path / "{index_name}.faiss".format(index_name=index_name))
)
faiss.write_index(self.index, str(path / f"{index_name}.faiss"))
# save docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
with open(path / f"{index_name}.pkl", "wb") as f:
pickle.dump((self.docstore, self.index_to_docstore_id), f)
@classmethod
@@ -1127,12 +1090,10 @@ class FAISS(VectorStore):
path = Path(folder_path)
# load index separately since it is not picklable
faiss = dependable_faiss_import()
index = faiss.read_index(
str(path / "{index_name}.faiss".format(index_name=index_name))
)
index = faiss.read_index(str(path / f"{index_name}.faiss"))
# load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
with open(path / f"{index_name}.pkl", "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
@@ -1235,3 +1196,36 @@ class FAISS(VectorStore):
(doc, relevance_score_fn(score)) for doc, score in docs_and_scores
]
return docs_and_rel_scores
@staticmethod
def _create_filter_func(
    filter: Optional[Union[Callable, Dict[str, Any]]],
) -> Callable[[Dict[str, Any]], bool]:
    """Build a metadata predicate from a user-supplied filter.

    Args:
        filter: Either a callable, which is handed back unchanged, or a
            dictionary mapping metadata keys to expected values. A list
            value matches when the metadata value is one of its elements;
            any other value must compare equal to the metadata value.

    Returns:
        A function that takes a Document's metadata dictionary and returns
        True when every filter condition holds, otherwise False.

    Raises:
        ValueError: If ``filter`` is neither a callable nor a dictionary.
    """
    # Callables take precedence over the dict check and are used as-is.
    if callable(filter):
        return filter

    if isinstance(filter, dict):
        # Bind the dict to a local so the closure below refers to a name
        # that is already known to be a dictionary.
        conditions = filter

        def metadata_matches(metadata: Dict[str, Any]) -> bool:
            for key, expected in conditions.items():
                actual = metadata.get(key)
                if isinstance(expected, list):
                    # A list means "any of these values is acceptable".
                    if actual not in expected:
                        return False
                elif not (actual == expected):
                    return False
            return True

        return metadata_matches

    raise ValueError(
        f"filter must be a dict of metadata or a callable, not {type(filter)}"
    )