add similarity score method to faiss (#574)

adds `similarity_search_with_score` to faiss wrapper
This commit is contained in:
Harrison Chase 2023-01-11 06:06:17 -08:00 committed by GitHub
parent 5ba46f6d0c
commit 2aa08631cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np import numpy as np
@ -77,6 +77,32 @@ class FAISS(VectorStore):
self.index_to_docstore_id.update(index_to_id) self.index_to_docstore_id.update(index_to_id)
return [_id for _, _id, _ in full_info] return [_id for _, _id, _ in full_info]
def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query and score for each
"""
embedding = self.embedding_function(query)
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j]))
return docs
def similarity_search(self, query: str, k: int = 4) -> List[Document]: def similarity_search(self, query: str, k: int = 4) -> List[Document]:
"""Return docs most similar to query. """Return docs most similar to query.
@ -87,19 +113,8 @@ class FAISS(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
embedding = self.embedding_function(query) docs_and_scores = self.similarity_search_with_score(query, k)
_, indices = self.index.search(np.array([embedding], dtype=np.float32), k) return [doc for doc, _ in docs_and_scores]
docs = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append(doc)
return docs
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20 self, query: str, k: int = 4, fetch_k: int = 20