add with score option for max marginal relevance (#6867)

### Adding the functionality to return the scores with retrieved
documents when using the max marginal relevance
- Description: Add the method
`max_marginal_relevance_search_with_score_by_vector` to the FAISS
wrapper. Functionality operates the same as
`similarity_search_with_score_by_vector` except for using the max
marginal relevance retrieval framework like is used in the
`max_marginal_relevance_search_by_vector` method.
  - Dependencies: None
  - Tag maintainer: @rlancemartin @eyurtsev 
  - Twitter handle: @RianDolphin

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Rian Dolphin 2023-06-29 08:00:34 +03:00 committed by GitHub
parent 398e4cd2dc
commit 2e39ede848
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 40 deletions

View File

@ -37,7 +37,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
else:
import faiss
except ImportError:
raise ValueError(
raise ImportError(
"Could not import faiss python package. "
"Please install it with `pip install faiss` "
"or `pip install faiss-cpu` (depending on Python version)."
@ -321,6 +321,73 @@ class FAISS(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch before filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents and similarity scores selected by maximal marginal
relevance and score for each.
"""
scores, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
)
if filter is not None:
filtered_indices = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
selected_indices = [indices[0][i] for i in mmr_selected]
selected_scores = [scores[0][i] for i in mmr_selected]
docs_and_scores = []
for i, score in zip(selected_indices, selected_scores):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs_and_scores.append((doc, score))
return docs_and_scores
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
@ -347,43 +414,10 @@ class FAISS(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
_, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
)
if filter is not None:
filtered_indices = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
selected_indices = [indices[0][i] for i in mmr_selected]
docs = []
for i in selected_indices:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append(doc)
return docs
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
self,
@ -414,8 +448,8 @@ class FAISS(VectorStore):
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k,
fetch_k,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,

View File

@ -46,9 +46,19 @@ def test_faiss_vector_sim() -> None:
output = docsearch.similarity_search_by_vector(query_vec, k=1)
assert output == [Document(page_content="foo")]
def test_faiss_mmr() -> None:
texts = ["foo", "foo", "fou", "foy"]
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
query_vec = FakeEmbeddings().embed_query(text="foo")
# make sure we can have k > docstore size
output = docsearch.max_marginal_relevance_search_by_vector(query_vec, k=10)
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1
)
assert len(output) == len(texts)
assert output[0][0] == Document(page_content="foo")
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo")
def test_faiss_with_metadatas() -> None: