Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-13 06:40:04 +00:00
add with score option for max marginal relevance (#6867)
### Adding the functionality to return the scores with retrieved documents when using max marginal relevance

- Description: Adds the method `max_marginal_relevance_search_with_score_by_vector` to the FAISS wrapper. It behaves the same as `similarity_search_with_score_by_vector`, except that it retrieves documents through the max marginal relevance framework, as `max_marginal_relevance_search_by_vector` does. A usage sketch follows the commit metadata below.
- Dependencies: None
- Tag maintainer: @rlancemartin @eyurtsev
- Twitter handle: @RianDolphin

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent 398e4cd2dc
commit 2e39ede848
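For orientation, a usage sketch of the new method. This is illustrative only: the texts are placeholders, `embeddings` stands in for whatever Embeddings implementation is in use, and only the method name and its keyword-only parameters (`k`, `fetch_k`, `lambda_mult`, `filter`) are taken from the diff below.

from langchain.vectorstores import FAISS

embeddings = ...  # placeholder: any langchain Embeddings implementation

# Build a small index and embed the query ourselves, since the method is vector-based.
docsearch = FAISS.from_texts(["foo", "foo", "fou", "foy"], embeddings)
query_vec = embeddings.embed_query("foo")

# Retrieve (Document, score) pairs selected by max marginal relevance.
docs_and_scores = docsearch.max_marginal_relevance_search_with_score_by_vector(
    query_vec, k=2, fetch_k=4, lambda_mult=0.5
)
for doc, score in docs_and_scores:
    print(doc.page_content, score)  # score is the raw FAISS distance for that doc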
@@ -37,7 +37,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
         else:
             import faiss
     except ImportError:
-        raise ValueError(
+        raise ImportError(
             "Could not import faiss python package. "
             "Please install it with `pip install faiss` "
             "or `pip install faiss-cpu` (depending on Python version)."
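A hedged sketch of what the changed exception type means for callers. The guard pattern below is illustrative and not part of this commit; the import path for `dependable_faiss_import` is an assumption based on the function shown in the hunk above.

from langchain.vectorstores.faiss import dependable_faiss_import

try:
    faiss = dependable_faiss_import()
except ImportError:  # previously this raised ValueError
    faiss = None  # FAISS-backed features unavailable; degrade gracefully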
@@ -321,6 +321,73 @@ class FAISS(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]

+    def max_marginal_relevance_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        *,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and their similarity scores selected using the maximal marginal
+            relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch before filtering to
+                     pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5.
+        Returns:
+            List of Documents and similarity scores selected by maximal marginal
+                relevance and score for each.
+        """
+        scores, indices = self.index.search(
+            np.array([embedding], dtype=np.float32),
+            fetch_k if filter is None else fetch_k * 2,
+        )
+        if filter is not None:
+            filtered_indices = []
+            for i in indices[0]:
+                if i == -1:
+                    # This happens when not enough docs are returned.
+                    continue
+                _id = self.index_to_docstore_id[i]
+                doc = self.docstore.search(_id)
+                if not isinstance(doc, Document):
+                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
+                if all(doc.metadata.get(key) == value for key, value in filter.items()):
+                    filtered_indices.append(i)
+            indices = np.array([filtered_indices])
+        # -1 happens when not enough docs are returned.
+        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+        mmr_selected = maximal_marginal_relevance(
+            np.array([embedding], dtype=np.float32),
+            embeddings,
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+        selected_indices = [indices[0][i] for i in mmr_selected]
+        selected_scores = [scores[0][i] for i in mmr_selected]
+        docs_and_scores = []
+        for i, score in zip(selected_indices, selected_scores):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            docs_and_scores.append((doc, score))
+        return docs_and_scores
+
     def max_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -347,43 +414,10 @@ class FAISS(VectorStore):
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
-        _, indices = self.index.search(
-            np.array([embedding], dtype=np.float32),
-            fetch_k if filter is None else fetch_k * 2,
-        )
-        if filter is not None:
-            filtered_indices = []
-            for i in indices[0]:
-                if i == -1:
-                    # This happens when not enough docs are returned.
-                    continue
-                _id = self.index_to_docstore_id[i]
-                doc = self.docstore.search(_id)
-                if not isinstance(doc, Document):
-                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
-                if all(doc.metadata.get(key) == value for key, value in filter.items()):
-                    filtered_indices.append(i)
-            indices = np.array([filtered_indices])
-        # -1 happens when not enough docs are returned.
-        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
-        mmr_selected = maximal_marginal_relevance(
-            np.array([embedding], dtype=np.float32),
-            embeddings,
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        selected_indices = [indices[0][i] for i in mmr_selected]
-        docs = []
-        for i in selected_indices:
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append(doc)
-        return docs
+        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
+            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
+        )
+        return [doc for doc, _ in docs_and_scores]

     def max_marginal_relevance_search(
         self,
@@ -414,8 +448,8 @@ class FAISS(VectorStore):
         embedding = self.embedding_function(query)
         docs = self.max_marginal_relevance_search_by_vector(
             embedding,
-            k,
-            fetch_k,
+            k=k,
+            fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             filter=filter,
             **kwargs,
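The docstring of the new method describes MMR as optimizing for similarity to the query and diversity among the selected documents. Below is a self-contained sketch of that greedy selection idea; it is illustrative only and is not langchain's `maximal_marginal_relevance` helper, and the use of cosine similarity plus the sample vectors are assumptions.

import numpy as np


def mmr_indices(query, candidates, k=4, lambda_mult=0.5):
    """Greedy MMR: trade off relevance to the query against redundancy with prior picks."""

    def cos(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    selected = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        best_i, best_score = remaining[0], float("-inf")
        for i in remaining:
            relevance = cos(query, candidates[i])
            # Redundancy = similarity to the closest already-selected candidate.
            redundancy = max((cos(candidates[i], candidates[j]) for j in selected), default=0.0)
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_i, best_score = i, score
        selected.append(best_i)
        remaining.remove(best_i)
    return selected


query = np.array([1.0, 0.0])
cands = np.array([[0.9, 0.1], [0.89, 0.11], [0.2, 0.9]])
# A diversity-heavy lambda_mult skips the near-duplicate of the first pick.
print(mmr_indices(query, cands, k=2, lambda_mult=0.3))  # [0, 2]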
@@ -46,9 +46,19 @@ def test_faiss_vector_sim() -> None:
     output = docsearch.similarity_search_by_vector(query_vec, k=1)
     assert output == [Document(page_content="foo")]


 def test_faiss_mmr() -> None:
     texts = ["foo", "foo", "fou", "foy"]
     docsearch = FAISS.from_texts(texts, FakeEmbeddings())
     query_vec = FakeEmbeddings().embed_query(text="foo")
     # make sure we can have k > docstore size
-    output = docsearch.max_marginal_relevance_search_by_vector(query_vec, k=10)
+    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
+        query_vec, k=10, lambda_mult=0.1
+    )
+    assert len(output) == len(texts)
+    assert output[0][0] == Document(page_content="foo")
+    assert output[0][1] == 0.0
+    assert output[1][0] != Document(page_content="foo")


 def test_faiss_with_metadatas() -> None:
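The `assert output[0][1] == 0.0` line reads on the score semantics: a flat L2 index, which the FAISS wrapper builds by default, reports squared L2 distances, so an exact vector match scores 0.0 and lower means closer. A minimal standalone sketch of that behaviour (assumes `faiss-cpu` and `numpy` are installed; the vectors are placeholders):

import numpy as np
import faiss

# Three stored vectors; the query equals the first one exactly.
xb = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
index = faiss.IndexFlatL2(xb.shape[1])
index.add(xb)

scores, indices = index.search(np.array([[0.0, 0.0]], dtype=np.float32), k=3)
print(scores)   # [[0. 1. 4.]] -- squared L2 distances; 0.0 for the exact match
print(indices)  # [[0 1 2]]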