mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
add similarity score method to faiss (#574)
adds `similarity_search_with_score` to faiss wrapper
This commit is contained in:
parent
5ba46f6d0c
commit
2aa08631cb
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -77,6 +77,32 @@ class FAISS(VectorStore):
|
|||||||
self.index_to_docstore_id.update(index_to_id)
|
self.index_to_docstore_id.update(index_to_id)
|
||||||
return [_id for _, _id, _ in full_info]
|
return [_id for _, _id, _ in full_info]
|
||||||
|
|
||||||
|
def similarity_search_with_score(
|
||||||
|
self, query: str, k: int = 4
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents most similar to the query and score for each
|
||||||
|
"""
|
||||||
|
embedding = self.embedding_function(query)
|
||||||
|
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
||||||
|
docs = []
|
||||||
|
for j, i in enumerate(indices[0]):
|
||||||
|
if i == -1:
|
||||||
|
# This happens when not enough docs are returned.
|
||||||
|
continue
|
||||||
|
_id = self.index_to_docstore_id[i]
|
||||||
|
doc = self.docstore.search(_id)
|
||||||
|
if not isinstance(doc, Document):
|
||||||
|
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||||
|
docs.append((doc, scores[0][j]))
|
||||||
|
return docs
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
@ -87,19 +113,8 @@ class FAISS(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query.
|
List of Documents most similar to the query.
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function(query)
|
docs_and_scores = self.similarity_search_with_score(query, k)
|
||||||
_, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
return [doc for doc, _ in docs_and_scores]
|
||||||
docs = []
|
|
||||||
for i in indices[0]:
|
|
||||||
if i == -1:
|
|
||||||
# This happens when not enough docs are returned.
|
|
||||||
continue
|
|
||||||
_id = self.index_to_docstore_id[i]
|
|
||||||
doc = self.docstore.search(_id)
|
|
||||||
if not isinstance(doc, Document):
|
|
||||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
self, query: str, k: int = 4, fetch_k: int = 20
|
self, query: str, k: int = 4, fetch_k: int = 20
|
||||||
|
Loading…
Reference in New Issue
Block a user