mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
commnity[patch]: refactor code for faiss vectorstore, update faiss vectorstore documentation (#18092)
**Description:** Refactor code of FAISS vectorcstore and update the related documentation. Details: - replace `.format()` with f-strings for strings formatting; - refactor definition of a filtering function to make code more readable and more flexible; - slightly improve efficiency of `max_marginal_relevance_search_with_score_by_vector` method by removing unnecessary looping over the same elements; - slightly improve efficiency of `delete` method by using set data structure for checking if the element was already deleted; **Issue:** fix small inconsistency in the documentation (the old example was incorrect and unappliable to faiss vectorstore) **Dependencies:** basic langchain-community dependencies and `faiss` (for CPU or for GPU) **Twitter handle:** antonenkodev
This commit is contained in:
@@ -119,9 +119,8 @@ class FAISS(VectorStore):
|
||||
and self._normalize_L2
|
||||
):
|
||||
warnings.warn(
|
||||
"Normalizing L2 is not applicable for metric type: {strategy}".format(
|
||||
strategy=self.distance_strategy
|
||||
)
|
||||
"Normalizing L2 is not applicable for "
|
||||
f"metric type: {self.distance_strategy}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -306,24 +305,7 @@ class FAISS(VectorStore):
|
||||
docs = []
|
||||
|
||||
if filter is not None:
|
||||
if isinstance(filter, dict):
|
||||
|
||||
def filter_func(metadata): # type: ignore[no-untyped-def]
|
||||
if all(
|
||||
metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
return True
|
||||
return False
|
||||
elif callable(filter):
|
||||
filter_func = filter
|
||||
else:
|
||||
raise ValueError(
|
||||
"filter must be a dict of metadata or "
|
||||
f"a callable, not {type(filter)}"
|
||||
)
|
||||
filter_func = self._create_filter_func(filter)
|
||||
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
@@ -608,25 +590,8 @@ class FAISS(VectorStore):
|
||||
fetch_k if filter is None else fetch_k * 2,
|
||||
)
|
||||
if filter is not None:
|
||||
filter_func = self._create_filter_func(filter)
|
||||
filtered_indices = []
|
||||
if isinstance(filter, dict):
|
||||
|
||||
def filter_func(metadata): # type: ignore[no-untyped-def]
|
||||
if all(
|
||||
metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
return True
|
||||
return False
|
||||
elif callable(filter):
|
||||
filter_func = filter
|
||||
else:
|
||||
raise ValueError(
|
||||
"filter must be a dict of metadata or "
|
||||
f"a callable, not {type(filter)}"
|
||||
)
|
||||
for i in indices[0]:
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
@@ -646,18 +611,18 @@ class FAISS(VectorStore):
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
selected_indices = [indices[0][i] for i in mmr_selected]
|
||||
selected_scores = [scores[0][i] for i in mmr_selected]
|
||||
|
||||
docs_and_scores = []
|
||||
for i, score in zip(selected_indices, selected_scores):
|
||||
if i == -1:
|
||||
for i in mmr_selected:
|
||||
if indices[0][i] == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
_id = self.index_to_docstore_id[indices[0][i]]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
docs_and_scores.append((doc, score))
|
||||
docs_and_scores.append((doc, scores[0][i]))
|
||||
|
||||
return docs_and_scores
|
||||
|
||||
async def amax_marginal_relevance_search_with_score_by_vector(
|
||||
@@ -857,9 +822,9 @@ class FAISS(VectorStore):
|
||||
)
|
||||
|
||||
reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
|
||||
index_to_delete = [reversed_index[id_] for id_ in ids]
|
||||
index_to_delete = {reversed_index[id_] for id_ in ids}
|
||||
|
||||
self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
|
||||
self.index.remove_ids(np.fromiter(index_to_delete, dtype=np.int64))
|
||||
self.docstore.delete(ids)
|
||||
|
||||
remaining_ids = [
|
||||
@@ -1079,12 +1044,10 @@ class FAISS(VectorStore):
|
||||
|
||||
# save index separately since it is not picklable
|
||||
faiss = dependable_faiss_import()
|
||||
faiss.write_index(
|
||||
self.index, str(path / "{index_name}.faiss".format(index_name=index_name))
|
||||
)
|
||||
faiss.write_index(self.index, str(path / f"{index_name}.faiss"))
|
||||
|
||||
# save docstore and index_to_docstore_id
|
||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
|
||||
with open(path / f"{index_name}.pkl", "wb") as f:
|
||||
pickle.dump((self.docstore, self.index_to_docstore_id), f)
|
||||
|
||||
@classmethod
|
||||
@@ -1127,12 +1090,10 @@ class FAISS(VectorStore):
|
||||
path = Path(folder_path)
|
||||
# load index separately since it is not picklable
|
||||
faiss = dependable_faiss_import()
|
||||
index = faiss.read_index(
|
||||
str(path / "{index_name}.faiss".format(index_name=index_name))
|
||||
)
|
||||
index = faiss.read_index(str(path / f"{index_name}.faiss"))
|
||||
|
||||
# load docstore and index_to_docstore_id
|
||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
|
||||
with open(path / f"{index_name}.pkl", "rb") as f:
|
||||
docstore, index_to_docstore_id = pickle.load(f)
|
||||
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
|
||||
|
||||
@@ -1235,3 +1196,36 @@ class FAISS(VectorStore):
|
||||
(doc, relevance_score_fn(score)) for doc, score in docs_and_scores
|
||||
]
|
||||
return docs_and_rel_scores
|
||||
|
||||
@staticmethod
|
||||
def _create_filter_func(
|
||||
filter: Optional[Union[Callable, Dict[str, Any]]],
|
||||
) -> Callable[[Dict[str, Any]], bool]:
|
||||
"""
|
||||
Create a filter function based on the provided filter.
|
||||
|
||||
Args:
|
||||
filter: A callable or a dictionary representing the filter
|
||||
conditions for documents.
|
||||
|
||||
Returns:
|
||||
Callable[[Dict[str, Any]], bool]: A function that takes Document's metadata
|
||||
and returns True if it satisfies the filter conditions, otherwise False.
|
||||
"""
|
||||
if callable(filter):
|
||||
return filter
|
||||
|
||||
if not isinstance(filter, dict):
|
||||
raise ValueError(
|
||||
f"filter must be a dict of metadata or a callable, not {type(filter)}"
|
||||
)
|
||||
|
||||
def filter_func(metadata: Dict[str, Any]) -> bool:
|
||||
return all(
|
||||
metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else metadata.get(key) == value
|
||||
for key, value in filter.items() # type: ignore
|
||||
)
|
||||
|
||||
return filter_func
|
||||
|
Reference in New Issue
Block a user