mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
add with score option for max marginal relevance (#6867)
### Adding the functionality to return the scores with retrieved documents when using the max marginal relevance - Description: Add the method `max_marginal_relevance_search_with_score_by_vector` to the FAISS wrapper. Functionality operates the same as `similarity_search_with_score_by_vector`, except that it uses the max marginal relevance retrieval framework, as is used in the `max_marginal_relevance_search_by_vector` method. - Dependencies: None - Tag maintainer: @rlancemartin @eyurtsev - Twitter handle: @RianDolphin --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
398e4cd2dc
commit
2e39ede848
@ -37,7 +37,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
|||||||
else:
|
else:
|
||||||
import faiss
|
import faiss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ImportError(
|
||||||
"Could not import faiss python package. "
|
"Could not import faiss python package. "
|
||||||
"Please install it with `pip install faiss` "
|
"Please install it with `pip install faiss` "
|
||||||
"or `pip install faiss-cpu` (depending on Python version)."
|
"or `pip install faiss-cpu` (depending on Python version)."
|
||||||
@ -321,6 +321,73 @@ class FAISS(VectorStore):
|
|||||||
)
|
)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: List[float],
    *,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs and their similarity scores selected using the maximal marginal
    relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch before filtering to
                 pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
                    of diversity among the results with 0 corresponding
                    to maximum diversity and 1 to minimum diversity.
                    Defaults to 0.5.
        filter: Optional metadata filter; only documents whose metadata
                matches every key/value pair are considered.
    Returns:
        List of Documents and similarity scores selected by maximal marginal
            relevance and score for each.
    """
    # Over-fetch when filtering so enough candidates survive the filter.
    scores, indices = self.index.search(
        np.array([embedding], dtype=np.float32),
        fetch_k if filter is None else fetch_k * 2,
    )
    # Record each FAISS id's raw score BEFORE any filtering. The original
    # code indexed scores[0] with positions from the *filtered* candidate
    # list, which misaligned documents and scores whenever `filter` was set
    # or fewer than fetch_k hits were returned (-1 padding entries).
    index_to_score = {
        int(i): float(s) for i, s in zip(indices[0], scores[0]) if i != -1
    }
    if filter is not None:
        filtered_indices = []
        for i in indices[0]:
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if all(doc.metadata.get(key) == value for key, value in filter.items()):
                filtered_indices.append(i)
        indices = np.array([filtered_indices])
    # -1 happens when not enough docs are returned. Keep kept_indices and
    # embeddings in lockstep so MMR positions map back to FAISS ids.
    kept_indices = [int(i) for i in indices[0] if i != -1]
    embeddings = [self.index.reconstruct(i) for i in kept_indices]
    mmr_selected = maximal_marginal_relevance(
        np.array([embedding], dtype=np.float32),
        embeddings,
        k=k,
        lambda_mult=lambda_mult,
    )
    docs_and_scores = []
    for pos in mmr_selected:
        # `pos` indexes the candidate list; translate back to the FAISS id.
        idx = kept_indices[pos]
        _id = self.index_to_docstore_id[idx]
        doc = self.docstore.search(_id)
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        docs_and_scores.append((doc, index_to_score[idx]))
    return docs_and_scores
|
||||||
|
|
||||||
def max_marginal_relevance_search_by_vector(
|
def max_marginal_relevance_search_by_vector(
|
||||||
self,
|
self,
|
||||||
embedding: List[float],
|
embedding: List[float],
|
||||||
@ -347,43 +414,10 @@ class FAISS(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
_, indices = self.index.search(
|
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
|
||||||
np.array([embedding], dtype=np.float32),
|
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
|
||||||
fetch_k if filter is None else fetch_k * 2,
|
|
||||||
)
|
)
|
||||||
if filter is not None:
|
return [doc for doc, _ in docs_and_scores]
|
||||||
filtered_indices = []
|
|
||||||
for i in indices[0]:
|
|
||||||
if i == -1:
|
|
||||||
# This happens when not enough docs are returned.
|
|
||||||
continue
|
|
||||||
_id = self.index_to_docstore_id[i]
|
|
||||||
doc = self.docstore.search(_id)
|
|
||||||
if not isinstance(doc, Document):
|
|
||||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
|
||||||
if all(doc.metadata.get(key) == value for key, value in filter.items()):
|
|
||||||
filtered_indices.append(i)
|
|
||||||
indices = np.array([filtered_indices])
|
|
||||||
# -1 happens when not enough docs are returned.
|
|
||||||
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
|
|
||||||
mmr_selected = maximal_marginal_relevance(
|
|
||||||
np.array([embedding], dtype=np.float32),
|
|
||||||
embeddings,
|
|
||||||
k=k,
|
|
||||||
lambda_mult=lambda_mult,
|
|
||||||
)
|
|
||||||
selected_indices = [indices[0][i] for i in mmr_selected]
|
|
||||||
docs = []
|
|
||||||
for i in selected_indices:
|
|
||||||
if i == -1:
|
|
||||||
# This happens when not enough docs are returned.
|
|
||||||
continue
|
|
||||||
_id = self.index_to_docstore_id[i]
|
|
||||||
doc = self.docstore.search(_id)
|
|
||||||
if not isinstance(doc, Document):
|
|
||||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
self,
|
self,
|
||||||
@ -414,8 +448,8 @@ class FAISS(VectorStore):
|
|||||||
embedding = self.embedding_function(query)
|
embedding = self.embedding_function(query)
|
||||||
docs = self.max_marginal_relevance_search_by_vector(
|
docs = self.max_marginal_relevance_search_by_vector(
|
||||||
embedding,
|
embedding,
|
||||||
k,
|
k=k,
|
||||||
fetch_k,
|
fetch_k=fetch_k,
|
||||||
lambda_mult=lambda_mult,
|
lambda_mult=lambda_mult,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -46,9 +46,19 @@ def test_faiss_vector_sim() -> None:
|
|||||||
output = docsearch.similarity_search_by_vector(query_vec, k=1)
|
output = docsearch.similarity_search_by_vector(query_vec, k=1)
|
||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_mmr() -> None:
    """MMR search with scores: when k exceeds the docstore size every stored
    document is returned, the exact match comes first with distance 0.0, and
    a low lambda_mult pushes a diverse (non-"foo") document into second place.
    """
    texts = ["foo", "foo", "fou", "foy"]
    store = FAISS.from_texts(texts, FakeEmbeddings())
    query_vec = FakeEmbeddings().embed_query(text="foo")
    # make sure we can have k > docstore size
    results = store.max_marginal_relevance_search_with_score_by_vector(
        query_vec, k=10, lambda_mult=0.1
    )
    assert len(results) == len(texts)
    best_doc, best_score = results[0]
    assert best_doc == Document(page_content="foo")
    assert best_score == 0.0
    runner_up, _ = results[1]
    assert runner_up != Document(page_content="foo")
|
||||||
|
|
||||||
|
|
||||||
def test_faiss_with_metadatas() -> None:
|
def test_faiss_with_metadatas() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user