diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py
index 1489bf8ff06..ce633789ccd 100644
--- a/langchain/vectorstores/faiss.py
+++ b/langchain/vectorstores/faiss.py
@@ -37,7 +37,7 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
         else:
             import faiss
     except ImportError:
-        raise ValueError(
+        raise ImportError(
             "Could not import faiss python package. "
             "Please install it with `pip install faiss` "
             "or `pip install faiss-cpu` (depending on Python version)."
@@ -321,6 +321,73 @@ class FAISS(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]
 
+    def max_marginal_relevance_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        *,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and their similarity scores selected using the maximal marginal
+        relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch before filtering to
+                pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents and similarity scores selected by maximal marginal
+                relevance and score for each.
+        """
+        scores, indices = self.index.search(
+            np.array([embedding], dtype=np.float32),
+            fetch_k if filter is None else fetch_k * 2,
+        )
+        if filter is not None:
+            filtered_indices = []
+            for i in indices[0]:
+                if i == -1:
+                    # This happens when not enough docs are returned.
+                    continue
+                _id = self.index_to_docstore_id[i]
+                doc = self.docstore.search(_id)
+                if not isinstance(doc, Document):
+                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
+                if all(doc.metadata.get(key) == value for key, value in filter.items()):
+                    filtered_indices.append(i)
+            indices = np.array([filtered_indices])
+        # -1 happens when not enough docs are returned.
+        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+        mmr_selected = maximal_marginal_relevance(
+            np.array([embedding], dtype=np.float32),
+            embeddings,
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+        selected_indices = [indices[0][i] for i in mmr_selected]
+        selected_scores = [scores[0][i] for i in mmr_selected]
+        docs_and_scores = []
+        for i, score in zip(selected_indices, selected_scores):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            docs_and_scores.append((doc, score))
+        return docs_and_scores
+
     def max_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -347,43 +414,10 @@ class FAISS(VectorStore):
         Returns:
             List of Documents selected by maximal marginal relevance.
         """
-        _, indices = self.index.search(
-            np.array([embedding], dtype=np.float32),
-            fetch_k if filter is None else fetch_k * 2,
+        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
+            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
         )
-        if filter is not None:
-            filtered_indices = []
-            for i in indices[0]:
-                if i == -1:
-                    # This happens when not enough docs are returned.
-                    continue
-                _id = self.index_to_docstore_id[i]
-                doc = self.docstore.search(_id)
-                if not isinstance(doc, Document):
-                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
-                if all(doc.metadata.get(key) == value for key, value in filter.items()):
-                    filtered_indices.append(i)
-            indices = np.array([filtered_indices])
-        # -1 happens when not enough docs are returned.
-        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
-        mmr_selected = maximal_marginal_relevance(
-            np.array([embedding], dtype=np.float32),
-            embeddings,
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        selected_indices = [indices[0][i] for i in mmr_selected]
-        docs = []
-        for i in selected_indices:
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append(doc)
-        return docs
+        return [doc for doc, _ in docs_and_scores]
 
     def max_marginal_relevance_search(
         self,
@@ -414,8 +448,8 @@ class FAISS(VectorStore):
         embedding = self.embedding_function(query)
         docs = self.max_marginal_relevance_search_by_vector(
             embedding,
-            k,
-            fetch_k,
+            k=k,
+            fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             filter=filter,
             **kwargs,
diff --git a/tests/integration_tests/vectorstores/test_faiss.py b/tests/integration_tests/vectorstores/test_faiss.py
index 37a66e8eb5b..270907461de 100644
--- a/tests/integration_tests/vectorstores/test_faiss.py
+++ b/tests/integration_tests/vectorstores/test_faiss.py
@@ -46,9 +46,19 @@ def test_faiss_vector_sim() -> None:
     output = docsearch.similarity_search_by_vector(query_vec, k=1)
     assert output == [Document(page_content="foo")]
 
+
+def test_faiss_mmr() -> None:
+    texts = ["foo", "foo", "fou", "foy"]
+    docsearch = FAISS.from_texts(texts, FakeEmbeddings())
+    query_vec = FakeEmbeddings().embed_query(text="foo")
     # make sure we can have k > docstore size
-    output = docsearch.max_marginal_relevance_search_by_vector(query_vec, k=10)
+    output = docsearch.max_marginal_relevance_search_with_score_by_vector(
+        query_vec, k=10, lambda_mult=0.1
+    )
     assert len(output) == len(texts)
+    assert output[0][0] == Document(page_content="foo")
+    assert output[0][1] == 0.0
+    assert output[1][0] != Document(page_content="foo")
 
 
 def test_faiss_with_metadatas() -> None: