Mirror of https://github.com/hwchase17/langchain.git (synced 2025-11-02 09:14:45 +00:00)
Add maximal marginal relevance search to SKLearnVectorStore (#5430)
# Add maximal marginal relevance search to SKLearnVectorStore

This PR implements maximal marginal relevance (MMR) search in SKLearnVectorStore.

Twitter handle: jtolgyesi (I also submitted the original implementation of SKLearnVectorStore.)

## Before submitting

Unit tests are included.

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
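As a quick illustration of what the new feature provides, here is a minimal usage sketch. The import paths and the embedding model are assumptions (any `Embeddings` implementation should work); the method names and the `k`/`fetch_k` parameters come from the tests added in this PR.

```python
# Minimal usage sketch for the MMR feature added by this PR.
# NOTE: the import paths and OpenAIEmbeddings are illustrative assumptions;
# the method names and k/fetch_k parameters mirror the tests in the diff below.
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import SKLearnVectorStore

texts = ["foo", "bar", "baz"]
embeddings = OpenAIEmbeddings()
store = SKLearnVectorStore.from_texts(texts, embeddings)

# Fetch the 3 nearest candidates, then return 1 document that balances
# relevance to the query against diversity among the selected results.
docs = store.max_marginal_relevance_search("foo", k=1, fetch_k=3)

# The same search can also be driven by a precomputed query embedding.
query_vec = embeddings.embed_query("foo")
docs = store.max_marginal_relevance_search_by_vector(query_vec, k=1, fetch_k=3)
```

The diff below updates the existing tests and adds two new ones, `test_sklearn_mmr` and `test_sklearn_mmr_by_vector`, covering both entry points.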
```diff
@@ -11,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
 def test_sklearn() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
-    docsearch = SKLearnVectorStore.from_texts(texts, embedding=FakeEmbeddings())
+    docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
     output = docsearch.similarity_search("foo", k=1)
     assert len(output) == 1
     assert output[0].page_content == "foo"
@@ -24,7 +24,7 @@ def test_sklearn_with_metadatas() -> None:
     metadatas = [{"page": str(i)} for i in range(len(texts))]
     docsearch = SKLearnVectorStore.from_texts(
         texts,
-        embedding=FakeEmbeddings(),
+        FakeEmbeddings(),
         metadatas=metadatas,
     )
     output = docsearch.similarity_search("foo", k=1)
@@ -38,7 +38,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
     metadatas = [{"page": str(i)} for i in range(len(texts))]
     docsearch = SKLearnVectorStore.from_texts(
         texts,
-        embedding=FakeEmbeddings(),
+        FakeEmbeddings(),
         metadatas=metadatas,
     )
     output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
@@ -69,8 +69,32 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
 
     # Get a new VectorStore from the persisted directory
     docsearch = SKLearnVectorStore(
-        embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
+        FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
     )
     output = docsearch.similarity_search("foo", k=1)
     assert len(output) == 1
     assert output[0].page_content == "foo"
+
+
+@pytest.mark.requires("numpy", "sklearn")
+def test_sklearn_mmr() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
+    output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
+    assert len(output) == 1
+    assert output[0].page_content == "foo"
+
+
+@pytest.mark.requires("numpy", "sklearn")
+def test_sklearn_mmr_by_vector() -> None:
+    """Test end to end construction and search."""
+    texts = ["foo", "bar", "baz"]
+    embeddings = FakeEmbeddings()
+    docsearch = SKLearnVectorStore.from_texts(texts, embeddings)
+    embedded_query = embeddings.embed_query("foo")
+    output = docsearch.max_marginal_relevance_search_by_vector(
+        embedded_query, k=1, fetch_k=3
+    )
+    assert len(output) == 1
+    assert output[0].page_content == "foo"
```
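For context on what the new tests exercise: maximal marginal relevance re-ranks the `fetch_k` nearest candidates so that the `k` returned documents stay relevant to the query while avoiding near-duplicates of each other. The sketch below is a generic, self-contained illustration of that selection rule (cosine similarity with a `lambda_mult` trade-off, a common but here assumed parameterization); it is not the code this PR adds to SKLearnVectorStore.

```python
# Generic MMR selection sketch (NumPy only). Illustrative, not the PR's implementation;
# the lambda_mult trade-off parameter is an assumption of this sketch.
import numpy as np


def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))


def mmr_select(query_vec, candidate_vecs, k=1, lambda_mult=0.5):
    """Return indices of k candidates balancing query relevance and diversity."""
    query_sims = [cosine(query_vec, c) for c in candidate_vecs]
    selected = []
    while len(selected) < min(k, len(candidate_vecs)):
        best_idx, best_score = -1, -float("inf")
        for i, cand in enumerate(candidate_vecs):
            if i in selected:
                continue
            # Penalize candidates that are too close to already-selected results.
            redundancy = max(
                (cosine(cand, candidate_vecs[j]) for j in selected), default=0.0
            )
            score = lambda_mult * query_sims[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected
```

With `k=1`, as in the new tests, this reduces to returning the single candidate most similar to the query, which is why the tests can assert that the result is `"foo"`.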