mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 17:11:25 +00:00
Add mmr to neo4j vector (#25765)
This commit is contained in:
parent
995305fdd5
commit
f359e6b0a5
@ -15,13 +15,17 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from langchain_community.graphs import Neo4jGraph
|
from langchain_community.graphs import Neo4jGraph
|
||||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
from langchain_community.vectorstores.utils import (
|
||||||
|
DistanceStrategy,
|
||||||
|
maximal_marginal_relevance,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
|
||||||
DISTANCE_MAPPING = {
|
DISTANCE_MAPPING = {
|
||||||
@ -1042,11 +1046,29 @@ class Neo4jVector(VectorStore):
|
|||||||
filter_params = {}
|
filter_params = {}
|
||||||
|
|
||||||
if self._index_type == IndexType.RELATIONSHIP:
|
if self._index_type == IndexType.RELATIONSHIP:
|
||||||
|
if kwargs.get("return_embeddings"):
|
||||||
|
default_retrieval = (
|
||||||
|
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
||||||
|
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
||||||
|
f"`{self.embedding_node_property}`: Null, id: Null, "
|
||||||
|
f"_embedding_: relationship.`{self.embedding_node_property}`}} "
|
||||||
|
"AS metadata"
|
||||||
|
)
|
||||||
|
else:
|
||||||
default_retrieval = (
|
default_retrieval = (
|
||||||
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
|
||||||
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
f"relationship {{.*, `{self.text_node_property}`: Null, "
|
||||||
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if kwargs.get("return_embeddings"):
|
||||||
|
default_retrieval = (
|
||||||
|
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||||
|
f"node {{.*, `{self.text_node_property}`: Null, "
|
||||||
|
f"`{self.embedding_node_property}`: Null, id: Null, "
|
||||||
|
f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
default_retrieval = (
|
default_retrieval = (
|
||||||
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
f"RETURN node.`{self.text_node_property}` AS text, score, "
|
||||||
@ -1083,6 +1105,20 @@ class Neo4jVector(VectorStore):
|
|||||||
"Inspect the `retrieval_query` and ensure it doesn't "
|
"Inspect the `retrieval_query` and ensure it doesn't "
|
||||||
"return None for the `text` column"
|
"return None for the `text` column"
|
||||||
)
|
)
|
||||||
|
if kwargs.get("return_embeddings") and any(
|
||||||
|
result["metadata"]["_embedding_"] is None for result in results
|
||||||
|
):
|
||||||
|
if not self.retrieval_query:
|
||||||
|
raise ValueError(
|
||||||
|
f"Make sure that none of the `{self.embedding_node_property}` "
|
||||||
|
f"properties on nodes with label `{self.node_label}` "
|
||||||
|
"are missing or empty"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Inspect the `retrieval_query` and ensure it doesn't "
|
||||||
|
"return None for the `_embedding_` metadata column"
|
||||||
|
)
|
||||||
|
|
||||||
docs = [
|
docs = [
|
||||||
(
|
(
|
||||||
@ -1487,6 +1523,64 @@ class Neo4jVector(VectorStore):
|
|||||||
break
|
break
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: search query text.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter: Filter on metadata properties, e.g.
            {
                "str_property": "foo",
                "int_property": 123
            }
    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    # Embed the query once; the same vector drives both the initial
    # similarity fetch and the MMR re-ranking below.
    embedded = self.embedding.embed_query(query)

    # Over-fetch fetch_k candidates, asking the store to also return the
    # stored embeddings so MMR has vectors to diversify over.
    candidates = self.similarity_search_with_score_by_vector(
        embedding=embedded,
        query=query,
        k=fetch_k,
        return_embeddings=True,
        filter=filter,
        **kwargs,
    )
    candidate_vectors = [doc.metadata["_embedding_"] for doc, _ in candidates]

    # Re-rank the candidate pool, trading relevance against diversity
    # according to lambda_mult.
    chosen_indices = maximal_marginal_relevance(
        np.array(embedded),
        candidate_vectors,
        lambda_mult=lambda_mult,
        k=k,
    )

    selected: List[Document] = []
    for idx in chosen_indices:
        doc = candidates[idx][0]
        # "_embedding_" is an internal, temporary metadata entry; strip it
        # before handing documents back to the caller.
        del doc.metadata["_embedding_"]
        selected.append(doc)

    return selected
|
||||||
|
|
||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
"""
|
"""
|
||||||
The 'correct' relevance function
|
The 'correct' relevance function
|
||||||
|
@ -14,7 +14,10 @@ from langchain_community.vectorstores.neo4j_vector import (
|
|||||||
_get_search_index_query,
|
_get_search_index_query,
|
||||||
)
|
)
|
||||||
from langchain_community.vectorstores.utils import DistanceStrategy
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
|
AngularTwoDimensionalEmbeddings,
|
||||||
|
FakeEmbeddings,
|
||||||
|
)
|
||||||
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
|
||||||
DOCUMENTS,
|
DOCUMENTS,
|
||||||
TYPE_1_FILTERING_TEST_CASES,
|
TYPE_1_FILTERING_TEST_CASES,
|
||||||
@ -928,6 +931,45 @@ OPTIONS {indexConfig: {
|
|||||||
drop_vector_indexes(docsearch)
|
drop_vector_indexes(docsearch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4j_max_marginal_relevance_search() -> None:
    """
    Test end to end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \
         /        |  v1
    v3  |     .   | query
         |        /  v0
          |______/                 (N.B. very crude drawing)

    With fetch_k==3 and k==2, when query is at (1, ),
    one expects that v2 and v0 are returned (in some order).
    """
    corpus = ["-0.124", "+0.127", "+0.25", "+1.0"]
    store = Neo4jVector.from_texts(
        corpus,
        metadatas=[{"page": idx} for idx in range(len(corpus))],
        embedding=AngularTwoDimensionalEmbeddings(),
        pre_delete_collection=True,
    )

    # k=2 from a pool of fetch_k=3 nearest: MMR should keep the closest
    # hit (v0) and swap the redundant v1 for the more diverse v2.
    hits = store.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
    found = {(doc.page_content, doc.metadata["page"]) for doc in hits}
    assert found == {
        ("+0.25", 2),
        ("-0.124", 0),
    }

    drop_vector_indexes(store)
|
||||||
|
|
||||||
|
|
||||||
def test_neo4jvector_passing_graph_object() -> None:
|
def test_neo4jvector_passing_graph_object() -> None:
|
||||||
"""Test end to end construction and search with passing graph object."""
|
"""Test end to end construction and search with passing graph object."""
|
||||||
graph = Neo4jGraph()
|
graph = Neo4jGraph()
|
||||||
|
Loading…
Reference in New Issue
Block a user