mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
add similarity score method to faiss (#574)
adds `similarity_search_with_score` to faiss wrapper
This commit is contained in:
parent
5ba46f6d0c
commit
2aa08631cb
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -77,6 +77,32 @@ class FAISS(VectorStore):
|
||||
self.index_to_docstore_id.update(index_to_id)
|
||||
return [_id for _, _id, _ in full_info]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query and score for each
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
||||
docs = []
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
docs.append((doc, scores[0][j]))
|
||||
return docs
|
||||
|
||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
@ -87,19 +113,8 @@ class FAISS(VectorStore):
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
_, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
||||
docs = []
|
||||
for i in indices[0]:
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
docs.append(doc)
|
||||
return docs
|
||||
docs_and_scores = self.similarity_search_with_score(query, k)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self, query: str, k: int = 4, fetch_k: int = 20
|
||||
|
Loading…
Reference in New Issue
Block a user