From 866d6408afd712145ce961ad45eab8c80b86b5f1 Mon Sep 17 00:00:00 2001 From: Jib Date: Mon, 18 Mar 2024 15:43:50 -0400 Subject: [PATCH] mongodb[patch]: Remove embedding retrieval from mongodb payload (#19035) ## Description Returning the embedding is not necessary in the vector search functionality unless specified as a debugging step. This change defaults the behavior such that the server _only_ returns the embedding key if explicitly requested, such as in the case of `max_marginal_relevance_search`. - [x] **Add tests and docs**: If you're adding a new integration, please include * Added `test_from_documents_no_embedding_return` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Erick Friis --- .../mongodb/langchain_mongodb/vectorstores.py | 29 +++++++----- .../integration_tests/test_vectorstores.py | 46 +++++++++++++++++++ 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/libs/partners/mongodb/langchain_mongodb/vectorstores.py b/libs/partners/mongodb/langchain_mongodb/vectorstores.py index afec8230713..af5501b5c8f 100644 --- a/libs/partners/mongodb/langchain_mongodb/vectorstores.py +++ b/libs/partners/mongodb/langchain_mongodb/vectorstores.py @@ -183,6 +183,8 @@ class MongoDBAtlasVectorSearch(VectorStore): k: int = 4, pre_filter: Optional[Dict] = None, post_filter_pipeline: Optional[List[Dict]] = None, + include_embedding: bool = False, + **kwargs: Any, ) -> List[Tuple[Document, float]]: params = { "queryVector": embedding, @@ -199,6 +201,11 @@ class MongoDBAtlasVectorSearch(VectorStore): query, {"$set": {"score": {"$meta": "vectorSearchScore"}}}, ] + + # Exclude the embedding key from the return payload + if not include_embedding: + pipeline.append({"$project": {self._embedding_key: 0}}) + if post_filter_pipeline is not None: pipeline.extend(post_filter_pipeline) cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type] @@ -215,6 +222,7 @@ class MongoDBAtlasVectorSearch(VectorStore): k: int = 4, pre_filter: Optional[Dict] = None, post_filter_pipeline: Optional[List[Dict]] = None, + **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return MongoDB documents most similar to the given query and their scores. @@ -238,6 +246,7 @@ class MongoDBAtlasVectorSearch(VectorStore): k=k, pre_filter=pre_filter, post_filter_pipeline=post_filter_pipeline, + **kwargs, ) return docs @@ -271,6 +280,7 @@ class MongoDBAtlasVectorSearch(VectorStore): k=k, pre_filter=pre_filter, post_filter_pipeline=post_filter_pipeline, + **kwargs, ) if additional and "similarity_score" in additional: @@ -310,20 +320,15 @@ class MongoDBAtlasVectorSearch(VectorStore): List of documents selected by maximal marginal relevance. """ query_embedding = self._embedding.embed_query(query) - docs = self._similarity_search_with_score( - query_embedding, - k=fetch_k, + return self.max_marginal_relevance_search_by_vector( + embedding=query_embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, pre_filter=pre_filter, post_filter_pipeline=post_filter_pipeline, + **kwargs, ) - mmr_doc_indexes = maximal_marginal_relevance( - np.array(query_embedding), - [doc.metadata[self._embedding_key] for doc, _ in docs], - k=k, - lambda_mult=lambda_mult, - ) - mmr_docs = [docs[i][0] for i in mmr_doc_indexes] - return mmr_docs @classmethod def from_texts( @@ -433,6 +438,8 @@ class MongoDBAtlasVectorSearch(VectorStore): k=fetch_k, pre_filter=pre_filter, post_filter_pipeline=post_filter_pipeline, + include_embedding=kwargs.pop("include_embedding", True), + **kwargs, ) mmr_doc_indexes = maximal_marginal_relevance( np.array(embedding), diff --git a/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py b/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py index b4b3fd28d47..16d4b17bafa 100644 --- a/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py @@ -91,6 +91,52 @@ class TestMongoDBAtlasVectorSearch: # Check for the presence of the metadata key assert any([key.page_content == output[0].page_content for key in documents]) + def test_from_documents_no_embedding_return( + self, embedding_openai: Embeddings, collection: Any + ) -> None: + """Test end to end construction and search.""" + documents = [ + Document(page_content="Dogs are tough.", metadata={"a": 1}), + Document(page_content="Cats have fluff.", metadata={"b": 1}), + Document(page_content="What is a sandwich?", metadata={"c": 1}), + Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), + ] + vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( + documents, + embedding_openai, + collection=collection, + index_name=INDEX_NAME, + ) + output = vectorstore.similarity_search("Sandwich", k=1) + assert len(output) == 1 + # Check for presence of embedding in each document + assert all(["embedding" not in key.metadata for key in output]) + # Check for the presence of the metadata key + assert any([key.page_content == output[0].page_content for key in documents]) + + def test_from_documents_embedding_return( + self, embedding_openai: Embeddings, collection: Any + ) -> None: + """Test end to end construction and search.""" + documents = [ + Document(page_content="Dogs are tough.", metadata={"a": 1}), + Document(page_content="Cats have fluff.", metadata={"b": 1}), + Document(page_content="What is a sandwich?", metadata={"c": 1}), + Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), + ] + vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( + documents, + embedding_openai, + collection=collection, + index_name=INDEX_NAME, + ) + output = vectorstore.similarity_search("Sandwich", k=1, include_embedding=True) + assert len(output) == 1 + # Check for presence of embedding in each document + assert all([key.metadata.get("embedding") for key in output]) + # Check for the presence of the metadata key + assert any([key.page_content == output[0].page_content for key in documents]) + def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None: texts = [ "Dogs are tough.",