Refactor vector storage to correctly handle relevancy scores (#6570)
Description: This pull request adds support for generating correct, generic relevancy scores across different vector stores by refactoring the relevance score functions and their selection in the VectorStore base class and its subclasses. This matters especially for vector stores that take a distance metric at initialization. Note that many of the current implementations of `_similarity_search_with_relevance_scores` are not technically correct: they simply return `self.similarity_search_with_score(query, k, **kwargs)` without applying the relevance score function.

Also includes changes associated with https://github.com/hwchase17/langchain/pull/6564 and https://github.com/hwchase17/langchain/pull/6494. See the more in-depth discussion in the thread in #6494.

Issue:
https://github.com/hwchase17/langchain/issues/6526
https://github.com/hwchase17/langchain/issues/6481
https://github.com/hwchase17/langchain/issues/6346

Dependencies: None

The changes include (a sketch of the resulting base-class pattern appears after this description):
- Properly handling score thresholding in FAISS `similarity_search_with_score_by_vector` for the corresponding distance metric.
- Refactoring the `_similarity_search_with_relevance_scores` method in the base class and removing it from the subclasses that implemented it incorrectly.
- Adding a `_select_relevance_score_fn` method in the base class and implementing it in the subclasses to select the appropriate relevance score function based on the distance strategy.
- Updating the `__init__` methods of the subclasses to set the `relevance_score_fn` attribute.
- Removing the `_default_relevance_score_fn` function from the FAISS class and using the base class's `_euclidean_relevance_score_fn` instead.
- Adding the `DistanceStrategy` enum to the `utils.py` file and updating the imports in the vector store classes.
- Updating the tests to import the `DistanceStrategy` enum from the `utils.py` file.

---------

Co-authored-by: Hanit <37485638+hanit-com@users.noreply.github.com>
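To make the description concrete, here is a minimal sketch of the base-class pattern it describes: a `DistanceStrategy` enum, per-metric normalization functions, a `_select_relevance_score_fn` hook, and a `_similarity_search_with_relevance_scores` that actually applies the selected function to the raw scores. The names follow the PR description, but the class name and function bodies are illustrative assumptions, not the exact langchain implementation.

# Illustrative sketch only: names follow the PR description; the bodies are
# assumptions about how the normalization could look, not langchain's code.
import math
from enum import Enum
from typing import Any, Callable, List, Tuple


class DistanceStrategy(str, Enum):
    """Distance metrics a vector store may be initialized with."""

    EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
    COSINE = "COSINE"
    MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"


class VectorStoreSketch:  # hypothetical stand-in for the VectorStore base class
    def __init__(self, distance_strategy: DistanceStrategy) -> None:
        self.distance_strategy = distance_strategy

    @staticmethod
    def _euclidean_relevance_score_fn(distance: float) -> float:
        # For unit-norm embeddings, L2 distance falls in [0, sqrt(2)],
        # so this rescales it onto a 0-1 relevance score.
        return 1.0 - distance / math.sqrt(2)

    @staticmethod
    def _cosine_relevance_score_fn(distance: float) -> float:
        # Cosine distance is 1 - cosine similarity, so relevance is its complement.
        return 1.0 - distance

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        # Subclasses pick the normalization matching their distance strategy;
        # an unknown strategy is a hard error rather than a silent passthrough.
        if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            return self._euclidean_relevance_score_fn
        if self.distance_strategy == DistanceStrategy.COSINE:
            return self._cosine_relevance_score_fn
        raise ValueError(
            f"No relevance score function for strategy {self.distance_strategy}"
        )

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Any, float]]:
        raise NotImplementedError  # provided by each concrete vector store

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Any, float]]:
        # The fix the PR calls out: apply the selected normalization instead of
        # returning the raw distance-based scores unchanged.
        relevance_score_fn = self._select_relevance_score_fn()
        docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
        return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]

Under this shape, the custom relevance_score_fn=lambda d: d * 0 used in the tests below simply replaces the selected normalization, which is why every score collapses to zero and the thresholded retriever returns nothing.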
@@ -228,3 +228,42 @@ def test_chroma_update_document() -> None:
    ]
    assert new_embedding == embedding.embed_documents([updated_content])[0]
    assert new_embedding != old_embedding


def test_chroma_with_relevance_score() -> None:
    """Test to make sure the relevance score is scaled to 0-1."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = Chroma.from_texts(
        collection_name="test_collection",
        texts=texts,
        embedding=FakeEmbeddings(),
        metadatas=metadatas,
        collection_metadata={"hnsw:space": "l2"},
    )
    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
    assert output == [
        (Document(page_content="foo", metadata={"page": "0"}), 1.0),
        (Document(page_content="bar", metadata={"page": "1"}), 0.8),
        (Document(page_content="baz", metadata={"page": "2"}), 0.5),
    ]


def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
    """Test searching with relevance score and custom normalization function."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = Chroma.from_texts(
        collection_name="test_collection",
        texts=texts,
        embedding=FakeEmbeddings(),
        metadatas=metadatas,
        relevance_score_fn=lambda d: d * 0,
        collection_metadata={"hnsw:space": "l2"},
    )
    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
    assert output == [
        (Document(page_content="foo", metadata={"page": "0"}), -0.0),
        (Document(page_content="bar", metadata={"page": "1"}), -0.0),
        (Document(page_content="baz", metadata={"page": "2"}), -0.0),
    ]
@@ -195,3 +195,14 @@ def test_faiss_invalid_normalize_fn() -> None:
    )
    with pytest.warns(Warning, match="scores must be between"):
        docsearch.similarity_search_with_relevance_scores("foo", k=1)


def test_missing_normalize_score_fn() -> None:
    """Test doesn't perform similarity search without a valid distance strategy."""
    with pytest.raises(ValueError):
        texts = ["foo", "bar", "baz"]
        faiss_instance = FAISS.from_texts(
            texts, FakeEmbeddings(), distance_strategy="fake"
        )

        faiss_instance.similarity_search_with_relevance_scores("foo", k=2)
@@ -184,3 +184,70 @@ def test_pgvector_with_filter_in_set() -> None:
        (Document(page_content="foo", metadata={"page": "0"}), 0.0),
        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
    ]


def test_pgvector_relevance_score() -> None:
    """Test to make sure the relevance score is scaled to 0-1."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
    assert output == [
        (Document(page_content="foo", metadata={"page": "0"}), 1.0),
        (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
        (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
    ]


def test_pgvector_retriever_search_threshold() -> None:
    """Test using retriever for searching with threshold."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    retriever = docsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 3, "score_threshold": 0.999},
    )
    output = retriever.get_relevant_documents("summer")
    assert output == [
        Document(page_content="foo", metadata={"page": "0"}),
        Document(page_content="bar", metadata={"page": "1"}),
    ]


def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
    """Test searching with threshold and custom normalization function"""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
        relevance_score_fn=lambda d: d * 0,
    )

    retriever = docsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 3, "score_threshold": 0.5},
    )
    output = retriever.get_relevant_documents("foo")
    assert output == []
@@ -2,6 +2,7 @@ import importlib
import os
import time
import uuid
import numpy as np
from typing import List

import pinecone
@@ -154,3 +155,21 @@ class TestPinecone:
        time.sleep(20)
        index_stats = self.index.describe_index_stats()
        assert index_stats["total_vector_count"] == len(texts) * 2

    @pytest.mark.vcr()
    def test_relevance_score_bound(self, embedding_openai: OpenAIEmbeddings) -> None:
        """Ensures all relevance scores are between 0 and 1."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = Pinecone.from_texts(
            texts,
            embedding_openai,
            index_name=index_name,
            metadatas=metadatas,
        )
        # wait for the index to be ready
        time.sleep(20)
        output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
        assert all(
            (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
        )
@@ -1,6 +1,7 @@
"""Test Qdrant functionality."""
import tempfile
from typing import Callable, Optional
import numpy as np

import pytest
from qdrant_client.http import models as rest
@@ -513,3 +514,26 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -
        vector_name in point.vector  # type: ignore[operator]
        for point in client.scroll(collection_name, with_vectors=True)[0]
    )


@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
def test_qdrant_similarity_search_with_relevance_scores(
    batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = Qdrant.from_texts(
        texts,
        ConsistentFakeEmbeddings(),
        location=":memory:",
        content_payload_key=content_payload_key,
        metadata_payload_key=metadata_payload_key,
        batch_size=batch_size,
    )
    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)

    assert all(
        (1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
    )
@@ -5,7 +5,8 @@ import numpy as np
import pytest

from langchain.docstore.document import Document
-from langchain.vectorstores.singlestoredb import DistanceStrategy, SingleStoreDB
+from langchain.vectorstores.singlestoredb import SingleStoreDB
+from langchain.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"