Merge branch 'vwp/similarity_search_with_distances' into vwp/characters

This commit is contained in:
vowelparrot
2023-04-15 18:50:12 -07:00
3 changed files with 29 additions and 12 deletions

View File

@@ -81,17 +81,17 @@ class VectorStore(ABC):
) -> List[Document]:
"""Return docs most similar to query."""
def similarity_search_with_normalized_similarities(
def similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and similarity scores, normalized on a scale from 0 to 1.
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
"""
docs_and_similarities = self._similarity_search_with_normalized_similarities(
docs_and_similarities = self._similarity_search_with_relevance_scores(
query, k=k, **kwargs
)
if any(
@@ -99,18 +99,18 @@ class VectorStore(ABC):
for _, similarity in docs_and_similarities
):
raise ValueError(
"Normalized similarity scores must be between"
"Relevance scores must be between"
f" 0 and 1, got {docs_and_similarities}"
)
return docs_and_similarities
def _similarity_search_with_normalized_similarities(
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and similarity scores, normalized on a scale from 0 to 1.
"""Return docs and relevance scores, normalized on a scale from 0 to 1.
0 is dissimilar, 1 is most similar.
"""

View File

@@ -1,6 +1,7 @@
"""Wrapper around FAISS vector database."""
from __future__ import annotations
import math
import pickle
import uuid
from pathlib import Path
@@ -29,6 +30,20 @@ def dependable_faiss_import() -> Any:
return faiss
def _default_normalize_score_fn(score: float) -> float:
"""Return a similarity score on a scale [0, 1]."""
# The 'correct' normalization function
# may differ depending on a few things, including:
# - the distance / similarity metric used by the VectorStore
# - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
# - embedding dimensionality
# - etc.
# This function converts the euclidean norm of normalized embeddings
# (0 is most similar, sqrt(2) most dissimilar)
# to a similarity function (0 to 1)
return 1.0 - score / math.sqrt(2)
class FAISS(VectorStore):
"""Wrapper around FAISS vector database.
@@ -48,7 +63,9 @@ class FAISS(VectorStore):
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
normalize_score_fn: Optional[Callable[[float], float]] = None,
normalize_score_fn: Optional[
Callable[[float], float]
] = _default_normalize_score_fn,
):
"""Initialize with necessary components."""
self.embedding_function = embedding_function
@@ -424,7 +441,7 @@ class FAISS(VectorStore):
docstore, index_to_docstore_id = pickle.load(f)
return cls(embeddings.embed_query, index, docstore, index_to_docstore_id)
def _similarity_search_with_normalized_similarities(
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,

View File

@@ -112,7 +112,7 @@ def test_faiss_local_save_load() -> None:
assert new_docsearch.index is not None
def test_faiss_similarity_search_with_normalized_similarities() -> None:
def test_faiss_similarity_search_with_relevance_scores() -> None:
"""Test the similarity search with relevance scores."""
texts = ["foo", "bar", "baz"]
docsearch = FAISS.from_texts(
@@ -120,7 +120,7 @@ def test_faiss_similarity_search_with_normalized_similarities() -> None:
FakeEmbeddings(),
normalize_score_fn=lambda score: 1.0 - score / math.sqrt(2),
)
outputs = docsearch.similarity_search_with_normalized_similarities("foo", k=1)
outputs = docsearch.similarity_search_with_relevance_scores("foo", k=1)
output, score = outputs[0]
assert output == Document(page_content="foo")
assert score == 1.0
@@ -135,7 +135,7 @@ def test_faiss_invalid_normalize_fn() -> None:
with pytest.raises(
ValueError, match="Relevance scores must be between 0 and 1"
):
docsearch.similarity_search_with_normalized_similarities("foo", k=1)
docsearch.similarity_search_with_relevance_scores("foo", k=1)
def test_missing_normalize_score_fn() -> None:
@@ -143,4 +143,4 @@ def test_missing_normalize_score_fn() -> None:
with pytest.raises(ValueError):
texts = ["foo", "bar", "baz"]
faiss_instance = FAISS.from_texts(texts, FakeEmbeddings())
faiss_instance.similarity_search_with_normalized_similarities("foo", k=2)
faiss_instance.similarity_search_with_relevance_scores("foo", k=2)