From e26906c1dc8c867fdfd2cbb1a67536f0fd5f7d99 Mon Sep 17 00:00:00 2001 From: Michael Landis Date: Mon, 4 Dec 2023 16:50:23 -0800 Subject: [PATCH] feat: implement max marginal relevance for momento vector index (#13619) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** Implements `max_marginal_relevance_search` and `max_marginal_relevance_search_by_vector` for the Momento Vector Index vectorstore. Additionally bumps the `momento` dependency in the lock file and adds logging to the implementation. **Dependencies** ✅ updates `momento` dependency in lock file **Tag maintainer** @baskaryan **Twitter handle** Please tag @momentohq for Momento Vector Index and @mloml for the contribution 🙇 --- .../vectorstores/momento_vector_index.py | 94 ++++++++++++++++++- libs/langchain/poetry.lock | 15 ++- .../vectorstores/test_momento_vector_index.py | 24 ++++- 3 files changed, 120 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/momento_vector_index.py b/libs/langchain/langchain/vectorstores/momento_vector_index.py index 8368fd6fc15..a969fa739f8 100644 --- a/libs/langchain/langchain/vectorstores/momento_vector_index.py +++ b/libs/langchain/langchain/vectorstores/momento_vector_index.py @@ -1,3 +1,4 @@ +import logging from typing import ( TYPE_CHECKING, Any, @@ -11,15 +12,17 @@ from typing import ( ) from uuid import uuid4 +import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from langchain.utils import get_from_env -from langchain.vectorstores.utils import DistanceStrategy +from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance VST = TypeVar("VST", bound="VectorStore") +logger = logging.getLogger(__name__) if TYPE_CHECKING: from momento import PreviewVectorIndexClient @@ -75,9 +78,8 @@ class MomentoVectorIndex(VectorStore): index_name (str, optional): The name of the index to store the documents in. Defaults to "default". distance_strategy (DistanceStrategy, optional): The distance strategy to - use. Defaults to DistanceStrategy.COSINE. If you select - DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared - Euclidean distance. + use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses + the squared Euclidean distance. Defaults to DistanceStrategy.COSINE. text_field (str, optional): The name of the metadata field to store the original text in. Defaults to "text". ensure_index_exists (bool, optional): Whether to ensure that the index @@ -125,6 +127,7 @@ class MomentoVectorIndex(VectorStore): elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY else: + logger.error(f"Distance strategy {self.distance_strategy} not implemented.") raise ValueError( f"Distance strategy {self.distance_strategy} not implemented." ) @@ -137,8 +140,10 @@ class MomentoVectorIndex(VectorStore): elif isinstance(response, CreateIndex.IndexAlreadyExists): return False elif isinstance(response, CreateIndex.Error): + logger.error(f"Error creating index: {response.inner_exception}") raise response.inner_exception else: + logger.error(f"Unexpected response: {response}") raise Exception(f"Unexpected response: {response}") def add_texts( @@ -331,6 +336,87 @@ class MomentoVectorIndex(VectorStore): ) return [doc for doc, _ in results] + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + from momento.requests.vector_index import ALL_METADATA + from momento.responses.vector_index import SearchAndFetchVectors + + response = self._client.search_and_fetch_vectors( + self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA + ) + + if isinstance(response, SearchAndFetchVectors.Success): + pass + elif isinstance(response, SearchAndFetchVectors.Error): + logger.error(f"Error searching and fetching vectors: {response}") + return [] + else: + logger.error(f"Unexpected response: {response}") + raise Exception(f"Unexpected response: {response}") + + mmr_selected = maximal_marginal_relevance( + query_embedding=np.array([embedding], dtype=np.float32), + embedding_list=[hit.vector for hit in response.hits], + lambda_mult=lambda_mult, + k=k, + ) + selected = [response.hits[i].metadata for i in mmr_selected] + return [ + Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501 + for metadata in selected + ] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding = self._embedding.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, **kwargs + ) + @classmethod def from_texts( cls: Type[VST], diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 4c983673e8c..8622ddc080d 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -3936,7 +3936,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -4958,29 +4957,29 @@ files = [ [[package]] name = "momento" -version = "1.13.0" +version = "1.14.1" description = "SDK for Momento" optional = true python-versions = ">=3.7,<4.0" files = [ - {file = "momento-1.13.0-py3-none-any.whl", hash = "sha256:dd5ace5b8d679e882afcefaa16bc413973c270b0a7a1c6c45f3eb60b0b9526de"}, - {file = "momento-1.13.0.tar.gz", hash = "sha256:39419627542b8f5997a777ff91aa3aaf6406b7d76fb83cd84284a0f7d1f9e356"}, + {file = "momento-1.14.1-py3-none-any.whl", hash = "sha256:241e46669e39c19627396f2b2b027a912861f1b8097fc9f97b05b76b3d90d199"}, + {file = "momento-1.14.1.tar.gz", hash = "sha256:d200a5e7463f7746a8a611474af1c245183d7ddf9346d9592760b78b6e801560"}, ] [package.dependencies] grpcio = ">=1.46.0,<2.0.0" -momento-wire-types = ">=0.91.1,<0.92.0" +momento-wire-types = ">=0.96.0,<0.97.0" pyjwt = ">=2.4.0,<3.0.0" [[package]] name = "momento-wire-types" -version = "0.91.4" +version = "0.96.0" description = "Momento Client Proto Generated Files" optional = true python-versions = ">=3.7,<4.0" files = [ - {file = "momento_wire_types-0.91.4-py3-none-any.whl", hash = "sha256:f296249693de2f6c383a397e7616b84dd83dfd466743d34b035b90865000a2a8"}, - {file = "momento_wire_types-0.91.4.tar.gz", hash = "sha256:de8cd14a12835d95997eb9b753ea47e1a5d2916658ec9320e416da8bd835fdff"}, + {file = "momento_wire_types-0.96.0-py3-none-any.whl", hash = "sha256:93dc0e3c31bbe1f664ce33974f235bc20e63b5e35ea8e118f0c5e5ed3cda7709"}, + {file = "momento_wire_types-0.96.0.tar.gz", hash = "sha256:9c6c839c698741c54b9fc3a4fe0f82094ea5102418b02bb271ed6e64ea6d7d9e"}, ] [package.dependencies] diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py b/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py index 7689088ac51..c4f20cf2e11 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py @@ -125,7 +125,7 @@ def test_from_texts_with_metadatas( def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None: - # """Test end to end construction and search with scores and IDs.""" + """Test end to end construction and search with scores and IDs.""" texts = ["apple", "orange", "hammer"] metadatas = [{"page": f"{i}"} for i in range(len(texts))] @@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None: ) assert isinstance(response, Search.Success) assert [hit.id for hit in response.hits] == ids + + +def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None: + """Test max marginal relevance search.""" + pepperoni_pizza = "pepperoni pizza" + cheese_pizza = "cheese pizza" + hot_dog = "hot dog" + + vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog]) + wait() + search_results = vector_store.similarity_search("pizza", k=2) + + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=cheese_pizza, metadata={}), + ] + + search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2) + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=hot_dog, metadata={}), + ]