Mirror of https://github.com/hwchase17/langchain.git
Add similarity_search_with_normalized_similarities (#2916)
Add a method that exposes a similarity search with corresponding normalized similarity scores. Implemented only for FAISS for now.

### Motivation:

Some memory definitions combine `relevance` with other scores, such as recency and importance. While many (but not all) of the `VectorStore`s expose a `similarity_search_with_score` method, the units of that score differ (they depend on the distance metric and on whether the embeddings are normalized). This PR proposes a `similarity_search_with_normalized_similarities` method that lets consumers of the vector store avoid worrying about the metric and the embedding scale. *Most providers default to euclidean distance, with Pinecone being one exception (it defaults to cosine _similarity_).*

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent b9db20481f
commit 4ffc58e07b
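For orientation, here is a minimal usage sketch of the consumer side this change is aimed at (not part of the diff). It assumes a FAISS store built with `FAISS.from_texts` and `OpenAIEmbeddings`; the recency weight and the 50/50 mix are purely hypothetical stand-ins for the memory-style scoring described above.

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Build a small FAISS store; OpenAIEmbeddings is just one possible embedding choice.
store = FAISS.from_texts(["hello world", "goodbye moon"], OpenAIEmbeddings())

# Scores come back on a fixed [0, 1] scale regardless of the underlying metric,
# so they can be mixed directly with other signals (recency, importance, ...).
for doc, relevance in store.similarity_search_with_relevance_scores("greeting", k=2):
    recency = 0.9  # hypothetical recency weight computed elsewhere
    combined = 0.5 * relevance + 0.5 * recency
    print(doc.page_content, relevance, combined)
```

Because the scores are guaranteed to land in [0, 1], the mixing step does not need to know whether the underlying index used euclidean distance or cosine similarity.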
@@ -4,7 +4,7 @@ from __future__ import annotations
 import asyncio
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar

 from pydantic import BaseModel, Field, root_validator

@@ -81,6 +81,41 @@ class VectorStore(ABC):
     ) -> List[Document]:
         """Return docs most similar to query."""

+    def similarity_search_with_relevance_scores(
+        self,
+        query: str,
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and relevance scores in the range [0, 1].
+
+        0 is dissimilar, 1 is most similar.
+        """
+        docs_and_similarities = self._similarity_search_with_relevance_scores(
+            query, k=k, **kwargs
+        )
+        if any(
+            similarity < 0.0 or similarity > 1.0
+            for _, similarity in docs_and_similarities
+        ):
+            raise ValueError(
+                "Relevance scores must be between"
+                f" 0 and 1, got {docs_and_similarities}"
+            )
+        return docs_and_similarities
+
+    def _similarity_search_with_relevance_scores(
+        self,
+        query: str,
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and relevance scores, normalized on a scale from 0 to 1.
+
+        0 is dissimilar, 1 is most similar.
+        """
+        raise NotImplementedError
+
     async def asimilarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> List[Document]:
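To make the new base-class contract concrete, the sketch below shows how another vector store might satisfy it. `MyCosineStore` is hypothetical, its abstract members are elided, and it assumes the store already exposes a `similarity_search_with_score` method returning cosine distances in [0, 2].

```python
from typing import Any, List, Tuple

from langchain.docstore.document import Document
from langchain.vectorstores.base import VectorStore


class MyCosineStore(VectorStore):
    """Hypothetical store whose raw scores are cosine distances in [0, 2]."""

    # The abstract members (add_texts, similarity_search, from_texts, ...) are
    # elided here, so this class is illustrative rather than instantiable.

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        # Assumes this store provides similarity_search_with_score and that it
        # returns cosine distances (0 = identical, 2 = opposite).
        docs_and_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        # Map onto [0, 1] with 1 meaning most similar, per the base-class contract.
        return [(doc, 1.0 - dist / 2.0) for doc, dist in docs_and_scores]
```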
@@ -1,6 +1,7 @@
 """Wrapper around FAISS vector database."""
 from __future__ import annotations

+import math
 import pickle
 import uuid
 from pathlib import Path
@@ -29,6 +30,20 @@ def dependable_faiss_import() -> Any:
     return faiss


+def _default_relevance_score_fn(score: float) -> float:
+    """Return a similarity score on a scale [0, 1]."""
+    # The 'correct' relevance function
+    # may differ depending on a few things, including:
+    # - the distance / similarity metric used by the VectorStore
+    # - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+    # - embedding dimensionality
+    # - etc.
+    # This function converts the euclidean norm of normalized embeddings
+    # (0 is most similar, sqrt(2) most dissimilar)
+    # to a similarity function (0 to 1)
+    return 1.0 - score / math.sqrt(2)
+
+
 class FAISS(VectorStore):
     """Wrapper around FAISS vector database.

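As a quick numeric check of the mapping described in the comments above (a sketch assuming the raw score really is the euclidean distance between unit-norm embeddings, as the default function is documented here): for unit vectors, ‖u − v‖ = √(2 − 2·⟨u, v⟩), so identical vectors map to 1.0 and orthogonal vectors to 0.0 under `1 - score / sqrt(2)`.

```python
import math

import numpy as np

u = np.array([1.0, 0.0])       # unit-norm query embedding
v_same = np.array([1.0, 0.0])  # identical document embedding
v_orth = np.array([0.0, 1.0])  # orthogonal document embedding


def relevance(raw_score: float) -> float:
    # Same formula as _default_relevance_score_fn above.
    return 1.0 - raw_score / math.sqrt(2)


print(relevance(float(np.linalg.norm(u - v_same))))  # 1.0 (distance 0)
print(relevance(float(np.linalg.norm(u - v_orth))))  # 0.0 (distance sqrt(2))
```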
@@ -48,12 +63,16 @@ class FAISS(VectorStore):
         index: Any,
         docstore: Docstore,
         index_to_docstore_id: Dict[int, str],
+        relevance_score_fn: Optional[
+            Callable[[float], float]
+        ] = _default_relevance_score_fn,
     ):
         """Initialize with necessary components."""
         self.embedding_function = embedding_function
         self.index = index
         self.docstore = docstore
         self.index_to_docstore_id = index_to_docstore_id
+        self.relevance_score_fn = relevance_score_fn

     def __add(
         self,
@@ -318,7 +337,7 @@ class FAISS(VectorStore):
         docstore = InMemoryDocstore(
             {index_to_id[i]: doc for i, doc in enumerate(documents)}
         )
-        return cls(embedding.embed_query, index, docstore, index_to_id)
+        return cls(embedding.embed_query, index, docstore, index_to_id, **kwargs)

     @classmethod
     def from_texts(
@@ -346,7 +365,13 @@ class FAISS(VectorStore):
                 faiss = FAISS.from_texts(texts, embeddings)
         """
         embeddings = embedding.embed_documents(texts)
-        return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
+        return cls.__from(
+            texts,
+            embeddings,
+            embedding,
+            metadatas,
+            **kwargs,
+        )

     @classmethod
     def from_embeddings(
@@ -375,7 +400,13 @@ class FAISS(VectorStore):
         """
         texts = [t[0] for t in text_embeddings]
         embeddings = [t[1] for t in text_embeddings]
-        return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
+        return cls.__from(
+            texts,
+            embeddings,
+            embedding,
+            metadatas,
+            **kwargs,
+        )

     def save_local(self, folder_path: str, index_name: str = "index") -> None:
         """Save FAISS index, docstore, and index_to_docstore_id to disk.
@@ -421,3 +452,18 @@ class FAISS(VectorStore):
         with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
             docstore, index_to_docstore_id = pickle.load(f)
         return cls(embeddings.embed_query, index, docstore, index_to_docstore_id)
+
+    def _similarity_search_with_relevance_scores(
+        self,
+        query: str,
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and their similarity scores on a scale from 0 to 1."""
+        if self.relevance_score_fn is None:
+            raise ValueError(
+                "normalize_score_fn must be provided to"
+                " FAISS constructor to normalize scores"
+            )
+        docs_and_scores = self.similarity_search_with_score(query, k=k)
+        return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
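Because `from_texts` and `from_embeddings` now forward `**kwargs` to the constructor, a caller can in principle swap in their own scoring function. The sketch below assumes the keyword must match the constructor parameter name above (`relevance_score_fn`) and uses an illustrative exponential decay rather than the linear default.

```python
import numpy as np

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Illustrative alternative scoring: squash the raw FAISS score with an exponential
# decay instead of the linear 1 - score / sqrt(2) default. exp(-x) stays in (0, 1]
# for any non-negative raw score, so the base-class range check still passes.
store = FAISS.from_texts(
    ["alpha", "beta", "gamma"],
    OpenAIEmbeddings(),
    relevance_score_fn=lambda score: float(np.exp(-score)),
)

for doc, rel in store.similarity_search_with_relevance_scores("alpha", k=2):
    print(doc.page_content, rel)
```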
@@ -1,4 +1,5 @@
 """Test FAISS functionality."""
+import math
 import tempfile

 import pytest
@@ -109,3 +110,37 @@ def test_faiss_local_save_load() -> None:
         docsearch.save_local(temp_file.name)
         new_docsearch = FAISS.load_local(temp_file.name, FakeEmbeddings())
     assert new_docsearch.index is not None
+
+
+def test_faiss_similarity_search_with_relevance_scores() -> None:
+    """Test the similarity search with normalized similarities."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = FAISS.from_texts(
+        texts,
+        FakeEmbeddings(),
+        normalize_score_fn=lambda score: 1.0 - score / math.sqrt(2),
+    )
+    outputs = docsearch.similarity_search_with_relevance_scores("foo", k=1)
+    output, score = outputs[0]
+    assert output == Document(page_content="foo")
+    assert score == 1.0
+
+
+def test_faiss_invalid_normalize_fn() -> None:
+    """Test the similarity search with normalized similarities."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = FAISS.from_texts(
+        texts, FakeEmbeddings(), normalize_score_fn=lambda _: 2.0
+    )
+    with pytest.raises(
+        ValueError, match="Normalized similarity scores must be between 0 and 1"
+    ):
+        docsearch.similarity_search_with_relevance_scores("foo", k=1)
+
+
+def test_missing_normalize_score_fn() -> None:
+    """Test doesn't perform similarity search without a normalize score function."""
+    with pytest.raises(ValueError):
+        texts = ["foo", "bar", "baz"]
+        faiss_instance = FAISS.from_texts(texts, FakeEmbeddings())
+        faiss_instance.similarity_search_with_relevance_scores("foo", k=2)