Add mmr to neo4j vector (#25765)

This commit is contained in:
Tomaz Bratanic
2024-08-27 14:55:19 +02:00
committed by GitHub
parent 995305fdd5
commit f359e6b0a5
2 changed files with 148 additions and 12 deletions

View File

@@ -14,7 +14,10 @@ from langchain_community.vectorstores.neo4j_vector import (
_get_search_index_query,
)
from langchain_community.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
AngularTwoDimensionalEmbeddings,
FakeEmbeddings,
)
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
@@ -928,6 +931,45 @@ OPTIONS {indexConfig: {
drop_vector_indexes(docsearch)
def test_neo4j_max_marginal_relevance_search() -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
______ v2
/ \
/ | v1
v3 | . | query
| / v0
|______/ (N.B. very crude drawing)
With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order).
"""
texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Neo4jVector.from_texts(
texts,
metadatas=metadatas,
embedding=AngularTwoDimensionalEmbeddings(),
pre_delete_collection=True,
)
expected_set = {
("+0.25", 2),
("-0.124", 0),
}
output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
output_set = {
(mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
}
assert output_set == expected_set
drop_vector_indexes(docsearch)
def test_neo4jvector_passing_graph_object() -> None:
"""Test end to end construction and search with passing graph object."""
graph = Neo4jGraph()