diff --git a/libs/langchain/langchain/vectorstores/elasticsearch.py b/libs/langchain/langchain/vectorstores/elasticsearch.py
index f76db9ccab1..f8a198dec94 100644
--- a/libs/langchain/langchain/vectorstores/elasticsearch.py
+++ b/libs/langchain/langchain/vectorstores/elasticsearch.py
@@ -14,10 +14,12 @@ from typing import (
     Union,
 )
 
+import numpy as np
+
 from langchain.docstore.document import Document
 from langchain.schema.embeddings import Embeddings
 from langchain.schema.vectorstore import VectorStore
-from langchain.vectorstores.utils import DistanceStrategy
+from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
 
 if TYPE_CHECKING:
     from elasticsearch import Elasticsearch
@@ -603,6 +605,67 @@ class ElasticsearchStore(VectorStore):
         results = self._search(query=query, k=k, filter=filter, **kwargs)
         return [doc for doc, _ in results]
 
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        fields: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query (str): Text to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult (float): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            fields: Other fields to get from elasticsearch source. These fields
+                will be added to the document metadata.
+
+        Returns:
+            List[Document]: A list of Documents selected by maximal marginal relevance.
+ """ + if self.embedding is None: + raise ValueError("You must provide an embedding function to perform MMR") + remove_vector_query_field_from_metadata = True + if fields is None: + fields = [self.vector_query_field] + elif self.vector_query_field not in fields: + fields.append(self.vector_query_field) + else: + remove_vector_query_field_from_metadata = False + + # Embed the query + query_embedding = self.embedding.embed_query(query) + + # Fetch the initial documents + got_docs = self._search( + query_vector=query_embedding, k=fetch_k, fields=fields, **kwargs + ) + + # Get the embeddings for the fetched documents + got_embeddings = [doc.metadata[self.vector_query_field] for doc, _ in got_docs] + + # Select documents using maximal marginal relevance + selected_indices = maximal_marginal_relevance( + np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k + ) + selected_docs = [got_docs[i][0] for i in selected_indices] + + if remove_vector_query_field_from_metadata: + for doc in selected_docs: + del doc.metadata["vector"] + + return selected_docs + def similarity_search_with_score( self, query: str, k: int = 4, filter: Optional[List[dict]] = None, **kwargs: Any ) -> List[Tuple[Document, float]]: @@ -665,7 +728,10 @@ class ElasticsearchStore(VectorStore): List of Documents most similar to the query and score for each """ if fields is None: - fields = ["metadata"] + fields = [] + + if "metadata" not in fields: + fields.append("metadata") if self.query_field not in fields: fields.append(self.query_field) @@ -689,7 +755,6 @@ class ElasticsearchStore(VectorStore): if custom_query is not None: query_body = custom_query(query_body, query) logger.debug(f"Calling custom_query, Query body now: {query_body}") - # Perform the kNN search on the Elasticsearch index and return the results. 
         response = self.client.search(
             index=self.index_name,
@@ -698,18 +763,24 @@ class ElasticsearchStore(VectorStore):
             source=fields,
         )
 
-        hits = [hit for hit in response["hits"]["hits"]]
-        docs_and_scores = [
-            (
-                Document(
-                    page_content=hit["_source"][self.query_field],
-                    metadata=hit["_source"]["metadata"],
-                ),
-                hit["_score"],
-            )
-            for hit in hits
-        ]
+        docs_and_scores = []
+        for hit in response["hits"]["hits"]:
+            for field in fields:
+                if field in hit["_source"] and field not in [
+                    "metadata",
+                    self.query_field,
+                ]:
+                    hit["_source"]["metadata"][field] = hit["_source"][field]
+            docs_and_scores.append(
+                (
+                    Document(
+                        page_content=hit["_source"][self.query_field],
+                        metadata=hit["_source"]["metadata"],
+                    ),
+                    hit["_score"],
+                )
+            )
 
         return docs_and_scores
 
     def delete(
diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py
index 64ce2dd77bb..8824d94e3a9 100644
--- a/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py
+++ b/libs/langchain/tests/integration_tests/vectorstores/test_elasticsearch.py
@@ -385,6 +385,39 @@ class TestElasticsearch:
             distance_strategy="NOT_A_STRATEGY",
         )
 
+    def test_max_marginal_relevance_search(
+        self, elasticsearch_connection: dict, index_name: str
+    ) -> None:
+        """Test max marginal relevance search."""
+        texts = ["foo", "bar", "baz"]
+        docsearch = ElasticsearchStore.from_texts(
+            texts,
+            FakeEmbeddings(),
+            **elasticsearch_connection,
+            index_name=index_name,
+            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
+        )
+
+        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3)
+        sim_output = docsearch.similarity_search(texts[0], k=3)
+        assert mmr_output == sim_output
+
+        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3)
+        assert len(mmr_output) == 2
+        assert mmr_output[0].page_content == texts[0]
+        assert mmr_output[1].page_content == texts[1]
+
+        mmr_output = docsearch.max_marginal_relevance_search(
+            texts[0], k=2, fetch_k=3, lambda_mult=0.1  # more diversity
+        )
+        assert len(mmr_output) == 2
+        assert mmr_output[0].page_content == texts[0]
+        assert mmr_output[1].page_content == texts[2]
+
+        # if fetch_k < k, then the output will be less than k
+        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2)
+        assert len(mmr_output) == 2
+
     def test_similarity_search_approx_with_hybrid_search(
         self, elasticsearch_connection: dict, index_name: str
     ) -> None:
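For context, here is a minimal usage sketch of the `max_marginal_relevance_search` method added by this diff. The connection URL, index name, and embedding model below are illustrative placeholders (any `Embeddings` implementation and reachable Elasticsearch cluster would do); they are not values taken from the PR.

```python
from langchain.embeddings import FakeEmbeddings  # stand-in embeddings; swap in a real model
from langchain.vectorstores.elasticsearch import ElasticsearchStore

# Placeholder connection details and index name (assumptions, not from the diff).
store = ElasticsearchStore.from_texts(
    ["foo", "bar", "baz"],
    FakeEmbeddings(size=10),
    es_url="http://localhost:9200",
    index_name="mmr-demo",
    strategy=ElasticsearchStore.ExactRetrievalStrategy(),
)

# Fetch 10 candidates by vector similarity, then keep the 3 that best balance
# relevance against diversity; lambda_mult closer to 0 favours diversity.
docs = store.max_marginal_relevance_search("foo", k=3, fetch_k=10, lambda_mult=0.5)
print([doc.page_content for doc in docs])
```

As the new integration test asserts, the method can return at most `fetch_k` documents when `fetch_k < k`, and the candidate vectors are read back from the document metadata, so the vector field is only kept in the returned metadata if it was explicitly requested via `fields`.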