diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py
index 42d9e25a05f..0f2f36c6b0d 100644
--- a/langchain/vectorstores/faiss.py
+++ b/langchain/vectorstores/faiss.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import uuid
-from typing import Any, Callable, Dict, Iterable, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 
@@ -77,6 +77,39 @@ class FAISS(VectorStore):
         self.index_to_docstore_id.update(index_to_id)
         return [_id for _, _id, _ in full_info]
 
+    def similarity_search_with_score(
+        self, query: str, k: int = 4
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query, along with their scores.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of (Document, score) tuples, in the order the index returned
+            them. Fewer than k tuples are returned when the index has fewer
+            than k matches.
+
+        Raises:
+            ValueError: If the docstore lookup for a returned id is not a Document.
+        """
+        embedding = self.embedding_function(query)
+        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+        docs = []
+        for score, i in zip(scores[0], indices[0]):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            # The index reports scores as NumPy scalars; convert them so the
+            # result matches the annotated ``float`` type.
+            docs.append((doc, float(score)))
+        return docs
+
     def similarity_search(self, query: str, k: int = 4) -> List[Document]:
         """Return docs most similar to query.
 
@@ -87,19 +120,8 @@ class FAISS(VectorStore):
         Returns:
             List of Documents most similar to the query.
         """
-        embedding = self.embedding_function(query)
-        _, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
-        docs = []
-        for i in indices[0]:
-            if i == -1:
-                # This happens when not enough docs are returned.
-                continue
-            _id = self.index_to_docstore_id[i]
-            doc = self.docstore.search(_id)
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            docs.append(doc)
-        return docs
+        docs_and_scores = self.similarity_search_with_score(query, k)
+        return [doc for doc, _ in docs_and_scores]
 
     def max_marginal_relevance_search(
         self, query: str, k: int = 4, fetch_k: int = 20