mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Add MMR functionality to elasticsearch retriever (#11633)
Allows MMR functionality only for the case where we have access to the embedding function. Also allows for users to request for fields from elasticsearch store. These are added to the document metadata. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ead9d5b55c
commit
361f8e1bc6
@ -14,10 +14,12 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
@ -603,6 +605,67 @@ class ElasticsearchStore(VectorStore):
|
||||
results = self._search(query=query, k=k, filter=filter, **kwargs)
|
||||
return [doc for doc, _ in results]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
fields: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query (str): Text to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult (float): Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
fields: Other fields to get from elasticsearch source. These fields
|
||||
will be added to the document metadata.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self.embedding is None:
|
||||
raise ValueError("You must provide an embedding function to perform MMR")
|
||||
remove_vector_query_field_from_metadata = True
|
||||
if fields is None:
|
||||
fields = [self.vector_query_field]
|
||||
elif self.vector_query_field not in fields:
|
||||
fields.append(self.vector_query_field)
|
||||
else:
|
||||
remove_vector_query_field_from_metadata = False
|
||||
|
||||
# Embed the query
|
||||
query_embedding = self.embedding.embed_query(query)
|
||||
|
||||
# Fetch the initial documents
|
||||
got_docs = self._search(
|
||||
query_vector=query_embedding, k=fetch_k, fields=fields, **kwargs
|
||||
)
|
||||
|
||||
# Get the embeddings for the fetched documents
|
||||
got_embeddings = [doc.metadata[self.vector_query_field] for doc, _ in got_docs]
|
||||
|
||||
# Select documents using maximal marginal relevance
|
||||
selected_indices = maximal_marginal_relevance(
|
||||
np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
|
||||
)
|
||||
selected_docs = [got_docs[i][0] for i in selected_indices]
|
||||
|
||||
if remove_vector_query_field_from_metadata:
|
||||
for doc in selected_docs:
|
||||
del doc.metadata["vector"]
|
||||
|
||||
return selected_docs
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, filter: Optional[List[dict]] = None, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
@ -665,7 +728,10 @@ class ElasticsearchStore(VectorStore):
|
||||
List of Documents most similar to the query and score for each
|
||||
"""
|
||||
if fields is None:
|
||||
fields = ["metadata"]
|
||||
fields = []
|
||||
|
||||
if "metadata" not in fields:
|
||||
fields.append("metadata")
|
||||
|
||||
if self.query_field not in fields:
|
||||
fields.append(self.query_field)
|
||||
@ -689,7 +755,6 @@ class ElasticsearchStore(VectorStore):
|
||||
if custom_query is not None:
|
||||
query_body = custom_query(query_body, query)
|
||||
logger.debug(f"Calling custom_query, Query body now: {query_body}")
|
||||
|
||||
# Perform the kNN search on the Elasticsearch index and return the results.
|
||||
response = self.client.search(
|
||||
index=self.index_name,
|
||||
@ -698,18 +763,24 @@ class ElasticsearchStore(VectorStore):
|
||||
source=fields,
|
||||
)
|
||||
|
||||
hits = [hit for hit in response["hits"]["hits"]]
|
||||
docs_and_scores = [
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][self.query_field],
|
||||
metadata=hit["_source"]["metadata"],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
docs_and_scores = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
for field in fields:
|
||||
if field in hit["_source"] and field not in [
|
||||
"metadata",
|
||||
self.query_field,
|
||||
]:
|
||||
hit["_source"]["metadata"][field] = hit["_source"][field]
|
||||
|
||||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][self.query_field],
|
||||
metadata=hit["_source"]["metadata"],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
)
|
||||
return docs_and_scores
|
||||
|
||||
def delete(
|
||||
|
@ -385,6 +385,39 @@ class TestElasticsearch:
|
||||
distance_strategy="NOT_A_STRATEGY",
|
||||
)
|
||||
|
||||
def test_max_marginal_relevance_search(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test max marginal relevance search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
||||
)
|
||||
|
||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3)
|
||||
sim_output = docsearch.similarity_search(texts[0], k=3)
|
||||
assert mmr_output == sim_output
|
||||
|
||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3)
|
||||
assert len(mmr_output) == 2
|
||||
assert mmr_output[0].page_content == texts[0]
|
||||
assert mmr_output[1].page_content == texts[1]
|
||||
|
||||
mmr_output = docsearch.max_marginal_relevance_search(
|
||||
texts[0], k=2, fetch_k=3, lambda_mult=0.1 # more diversity
|
||||
)
|
||||
assert len(mmr_output) == 2
|
||||
assert mmr_output[0].page_content == texts[0]
|
||||
assert mmr_output[1].page_content == texts[2]
|
||||
|
||||
# if fetch_k < k, then the output will be less than k
|
||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2)
|
||||
assert len(mmr_output) == 2
|
||||
|
||||
def test_similarity_search_approx_with_hybrid_search(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user