add similarity score method to faiss (#574)

adds `similarity_search_with_score` to faiss wrapper
This commit is contained in:
Harrison Chase 2023-01-11 06:06:17 -08:00 committed by GitHub
parent 5ba46f6d0c
commit 2aa08631cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
@ -77,6 +77,32 @@ class FAISS(VectorStore):
self.index_to_docstore_id.update(index_to_id)
return [_id for _, _id, _ in full_info]
def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query and score for each
"""
embedding = self.embedding_function(query)
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j]))
return docs
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
"""Return docs most similar to query.
@ -87,19 +113,8 @@ class FAISS(VectorStore):
Returns:
List of Documents most similar to the query.
"""
embedding = self.embedding_function(query)
_, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
docs = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append(doc)
return docs
docs_and_scores = self.similarity_search_with_score(query, k)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
self, query: str, k: int = 4, fetch_k: int = 20