Add maximal relevance search to SKLearnVectorStore (#5430)

# Add maximal relevance search to SKLearnVectorStore

This PR implements the maximum relevance search in SKLearnVectorStore. 

Twitter handle: jtolgyesi (I submitted also the original implementation
of SKLearnVectorStore)

## Before submitting

Unit tests are included.

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Janos Tolgyesi
2023-05-31 01:13:33 +02:00
committed by GitHub
parent 8181f9e362
commit 1111f18eb4
2 changed files with 134 additions and 18 deletions

View File

@@ -11,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
def test_sklearn() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = SKLearnVectorStore.from_texts(texts, embedding=FakeEmbeddings())
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
output = docsearch.similarity_search("foo", k=1)
assert len(output) == 1
assert output[0].page_content == "foo"
@@ -24,7 +24,7 @@ def test_sklearn_with_metadatas() -> None:
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = SKLearnVectorStore.from_texts(
texts,
embedding=FakeEmbeddings(),
FakeEmbeddings(),
metadatas=metadatas,
)
output = docsearch.similarity_search("foo", k=1)
@@ -38,7 +38,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = SKLearnVectorStore.from_texts(
texts,
embedding=FakeEmbeddings(),
FakeEmbeddings(),
metadatas=metadatas,
)
output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
@@ -69,8 +69,32 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
# Get a new VectorStore from the persisted directory
docsearch = SKLearnVectorStore(
embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
)
output = docsearch.similarity_search("foo", k=1)
assert len(output) == 1
assert output[0].page_content == "foo"
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
assert len(output) == 1
assert output[0].page_content == "foo"
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr_by_vector() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
embeddings = FakeEmbeddings()
docsearch = SKLearnVectorStore.from_texts(texts, embeddings)
embedded_query = embeddings.embed_query("foo")
output = docsearch.max_marginal_relevance_search_by_vector(
embedded_query, k=1, fetch_k=3
)
assert len(output) == 1
assert output[0].page_content == "foo"