adding max_marginal_relevance_search method to MongoDBAtlasVectorSearch (#7310)

Adding a max_marginal_relevance_search method to the MongoDBAtlasVectorSearch vectorstore enhances the user experience by providing more diverse search results.

Issue: #7304
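A minimal sketch of calling the new method against an existing Atlas collection (the connection-string placeholder comes from the class docstring; the database, collection, and index names are the ones used in the updated notebook and should be replaced with your own):

from pymongo import MongoClient

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import MongoDBAtlasVectorSearch

client = MongoClient("<YOUR-CONNECTION-STRING>")
collection = client["langchain_db"]["langchain_col"]
vectorstore = MongoDBAtlasVectorSearch(
    collection, OpenAIEmbeddings(), index_name="langchain_demo"
)

query = "What did the president say about Ketanji Brown Jackson"
# Atlas returns fetch_k candidates from the knnBeta search; MMR then keeps the k
# results that best balance relevance to the query against diversity.
docs = vectorstore.max_marginal_relevance_search(query, k=4, fetch_k=20, lambda_mult=0.5)
print(docs[0].page_content)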
Paul-Emile Brotons 2023-07-10 04:04:19 -04:00 committed by GitHub
parent 04cddfba0d
commit d2cf0d16b3
3 changed files with 349 additions and 251 deletions


@@ -1,10 +1,14 @@
 {
 "cells":[
 {
-"attachments": {},
+"attachments":{
+},
 "cell_type":"markdown",
 "id":"683953b3",
-"metadata": {},
+"metadata":{
+},
 "source":[
 "# MongoDB Atlas\n",
 "\n",
@@ -23,9 +27,13 @@
 "execution_count":null,
 "id":"b4c41cad-08ef-4f72-a545-2151e4598efe",
 "metadata":{
-"tags": []
+"tags":[
+]
 },
-"outputs": [],
+"outputs":[
+],
 "source":[
 "!pip install pymongo"
 ]
@@ -34,21 +42,28 @@
 "cell_type":"code",
 "execution_count":null,
 "id":"c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
-"metadata": {},
-"outputs": [],
+"metadata":{
+},
+"outputs":[
+],
 "source":[
 "import os\n",
 "import getpass\n",
 "\n",
-"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
-"MONGODB_ATLAS_CLUSTER_URI = os.environ[\"MONGODB_ATLAS_CLUSTER_URI\"]"
+"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n"
 ]
 },
 {
-"attachments": {},
+"attachments":{
+},
 "cell_type":"markdown",
 "id":"457ace44-1d95-4001-9dd5-78811ab208ad",
-"metadata": {},
+"metadata":{
+},
 "source":[
 "We want to use `OpenAIEmbeddings` so we need to set up our OpenAI API Key. "
 ]
@@ -57,18 +72,25 @@
 "cell_type":"code",
 "execution_count":null,
 "id":"2d8f240d",
-"metadata": {},
-"outputs": [],
+"metadata":{
+},
+"outputs":[
+],
 "source":[
-"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
-"OPENAI_API_KEY = os.environ[\"OPENAI_API_KEY\"]"
+"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n"
 ]
 },
 {
-"attachments": {},
+"attachments":{
+},
 "cell_type":"markdown",
 "id":"1f3ecc42",
-"metadata": {},
+"metadata":{
+},
 "source":[
 "Now, let's create a vector search index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Vector Search index.\n",
 "You can name the index `langchain_demo` and create the index on the namespace `lanchain_db.langchain_col`. Finally, write the following definition in the JSON editor on MongoDB Atlas:\n",
@@ -94,23 +116,17 @@
 "execution_count":2,
 "id":"aac9563e",
 "metadata":{
-"tags": []
+"tags":[
+]
 },
-"outputs": [],
+"outputs":[
+],
 "source":[
 "from langchain.embeddings.openai import OpenAIEmbeddings\n",
 "from langchain.text_splitter import CharacterTextSplitter\n",
 "from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
-"from langchain.document_loaders import TextLoader"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 2,
-"id": "a3c3999a",
-"metadata": {},
-"outputs": [],
-"source": [
 "from langchain.document_loaders import TextLoader\n",
 "\n",
 "loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
@@ -125,8 +141,12 @@
 "cell_type":"code",
 "execution_count":null,
 "id":"6e104aee",
-"metadata": {},
-"outputs": [],
+"metadata":{
+},
+"outputs":[
+],
 "source":[
 "from pymongo import MongoClient\n",
 "\n",
@@ -152,51 +172,48 @@
 "cell_type":"code",
 "execution_count":null,
 "id":"9c608226",
-"metadata": {},
-"outputs": [],
+"metadata":{
+},
+"outputs":[
+],
 "source":[
 "print(docs[0].page_content)"
 ]
 },
 {
-"attachments": {},
+"attachments":{
+},
 "cell_type":"markdown",
 "id":"851a2ec9-9390-49a4-8412-3e132c9f789d",
-"metadata": {},
+"metadata":{
+},
 "source":[
-"You can reuse the vector search index you created, make sure the `OPENAI_API_KEY` environment variable is set up, then execute another query."
+"You can also instantiate the vector store directly and execute a query as follows:"
 ]
 },
 {
 "cell_type":"code",
 "execution_count":null,
 "id":"6336fe79-3e73-48be-b20a-0ff1bb6a4399",
-"metadata": {},
-"outputs": [],
+"metadata":{
+},
+"outputs":[
+],
 "source":[
-"from pymongo import MongoClient\n",
-"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
-"from langchain.embeddings.openai import OpenAIEmbeddings\n",
-"import os\n",
-"\n",
-"MONGODB_ATLAS_URI = os.environ[\"MONGODB_ATLAS_URI\"]\n",
-"\n",
-"# initialize MongoDB python client\n",
-"client = MongoClient(MONGODB_ATLAS_URI)\n",
-"\n",
-"db_name = \"langchain_db\"\n",
-"collection_name = \"langchain_col\"\n",
-"collection = client[db_name][collection_name]\n",
-"index_name = \"langchain_demo\"\n",
-"\n",
 "# initialize vector store\n",
-"vectorStore = MongoDBAtlasVectorSearch(\n",
+"vectorstore = MongoDBAtlasVectorSearch(\n",
 " collection, OpenAIEmbeddings(), index_name=index_name\n",
 ")\n",
 "\n",
 "# perform a similarity search between a query and the ingested documents\n",
 "query = \"What did the president say about Ketanji Brown Jackson\"\n",
-"docs = vectorStore.similarity_search(query)\n",
+"docs = vectorstore.similarity_search(query)\n",
 "\n",
 "print(docs[0].page_content)"
 ]


@@ -14,9 +14,12 @@ from typing import (
     Union,
 )
 
+import numpy as np
+
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.utils import maximal_marginal_relevance
 
 if TYPE_CHECKING:
     from pymongo.collection import Collection
@@ -137,6 +140,39 @@ class MongoDBAtlasVectorSearch(VectorStore):
         insert_result = self._collection.insert_many(to_insert)
         return insert_result.inserted_ids
 
+    def _similarity_search_with_score(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        pre_filter: Optional[dict] = None,
+        post_filter_pipeline: Optional[List[Dict]] = None,
+    ) -> List[Tuple[Document, float]]:
+        knn_beta = {
+            "vector": embedding,
+            "path": self._embedding_key,
+            "k": k,
+        }
+        if pre_filter:
+            knn_beta["filter"] = pre_filter
+        pipeline = [
+            {
+                "$search": {
+                    "index": self._index_name,
+                    "knnBeta": knn_beta,
+                }
+            },
+            {"$set": {"score": {"$meta": "searchScore"}}},
+        ]
+        if post_filter_pipeline is not None:
+            pipeline.extend(post_filter_pipeline)
+        cursor = self._collection.aggregate(pipeline)
+        docs = []
+        for res in cursor:
+            text = res.pop(self._text_key)
+            score = res.pop("score")
+            docs.append((Document(page_content=text, metadata=res), score))
+        return docs
+
     def similarity_search_with_score(
         self,
         query: str,
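The private helper added above reduces to a single aggregate() call. For a query with k=4 and no filters, the pipeline it builds looks roughly like this (the index and field names are the notebook's defaults and purely illustrative):

embedding = [0.01] * 1536  # placeholder; real calls pass the embedded query text

pipeline = [
    {
        "$search": {
            "index": "langchain_demo",  # self._index_name
            "knnBeta": {"vector": embedding, "path": "embedding", "k": 4},
        }
    },
    # copy Lucene's relevance score onto each hit so it can be popped off afterwards
    {"$set": {"score": {"$meta": "searchScore"}}},
]
# cursor = collection.aggregate(pipeline)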
@@ -165,30 +201,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
         Returns:
             List of Documents most similar to the query and score for each
         """
-        knn_beta = {
-            "vector": self._embedding.embed_query(query),
-            "path": self._embedding_key,
-            "k": k,
-        }
-        if pre_filter:
-            knn_beta["filter"] = pre_filter
-        pipeline = [
-            {
-                "$search": {
-                    "index": self._index_name,
-                    "knnBeta": knn_beta,
-                }
-            },
-            {"$project": {"score": {"$meta": "searchScore"}, self._embedding_key: 0}},
-        ]
-        if post_filter_pipeline is not None:
-            pipeline.extend(post_filter_pipeline)
-        cursor = self._collection.aggregate(pipeline)
-        docs = []
-        for res in cursor:
-            text = res.pop(self._text_key)
-            score = res.pop("score")
-            docs.append((Document(page_content=text, metadata=res), score))
+        embedding = self._embedding.embed_query(query)
+        docs = self._similarity_search_with_score(
+            embedding,
+            k=k,
+            pre_filter=pre_filter,
+            post_filter_pipeline=post_filter_pipeline,
+        )
         return docs
 
     def similarity_search(
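A side effect of this refactor: the old pipeline's $project stage stripped the embedding field from each hit, while the new $set stage leaves documents intact, so every returned Document now carries its stored vector in metadata, which is what the MMR method added below re-ranks on. A small sketch, assuming a vectorstore built as in the notebook and the default "embedding" field name:

docs_and_scores = vectorstore.similarity_search_with_score(
    "What did the president say about Ketanji Brown Jackson", k=4
)
for doc, score in docs_and_scores:
    vector = doc.metadata["embedding"]  # present again now that it is no longer projected out
    print(round(score, 3), len(vector), doc.page_content[:60])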
@@ -227,6 +246,53 @@ class MongoDBAtlasVectorSearch(VectorStore):
         )
         return [doc for doc, _ in docs_and_scores]
 
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        pre_filter: Optional[dict] = None,
+        post_filter_pipeline: Optional[List[Dict]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Optional Number of Documents to return. Defaults to 4.
+            fetch_k: Optional Number of Documents to fetch before passing to MMR
+                algorithm. Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            pre_filter: Optional Dictionary of argument(s) to prefilter on document
+                fields.
+            post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
+                following the knnBeta search.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        query_embedding = self._embedding.embed_query(query)
+        docs = self._similarity_search_with_score(
+            query_embedding,
+            k=fetch_k,
+            pre_filter=pre_filter,
+            post_filter_pipeline=post_filter_pipeline,
+        )
+        mmr_doc_indexes = maximal_marginal_relevance(
+            np.array(query_embedding),
+            [doc.metadata[self._embedding_key] for doc, _ in docs],
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+        mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
+        return mmr_docs
+
     @classmethod
     def from_texts(
         cls,
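The re-ranking itself happens in langchain's maximal_marginal_relevance utility. As a rough sketch of the criterion rather than the library's exact implementation: greedy MMR repeatedly picks the candidate maximizing lambda_mult * similarity-to-query minus (1 - lambda_mult) * similarity-to-already-selected, which is why lambda_mult=1 reduces to plain similarity ranking and lambda_mult=0 pushes hardest for diversity:

import numpy as np

def mmr_select(query_vec, doc_vecs, k=4, lambda_mult=0.5):
    """Greedy MMR over cosine similarities (illustrative sketch; expects numpy arrays)."""
    def cos(a, b):
        return float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

    query_sims = [cos(query_vec, d) for d in doc_vecs]
    selected = [int(np.argmax(query_sims))]  # most relevant document first
    while len(selected) < min(k, len(doc_vecs)):
        best_idx, best_score = -1, -np.inf
        for i in range(len(doc_vecs)):
            if i in selected:
                continue
            redundancy = max(cos(doc_vecs[i], doc_vecs[j]) for j in selected)
            score = lambda_mult * query_sims[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected

# e.g. mmr_select(np.array(q_vec), [np.array(v) for v in doc_vectors], k=4, lambda_mult=0.1)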
@@ -252,7 +318,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
                 from langchain.vectorstores import MongoDBAtlasVectorSearch
                 from langchain.embeddings import OpenAIEmbeddings
 
-                client = MongoClient("<YOUR-CONNECTION-STRING>")
+                mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
                 collection = mongo_client["<db_name>"]["<collection_name>"]
                 embeddings = OpenAIEmbeddings()
                 vectorstore = MongoDBAtlasVectorSearch.from_texts(
@@ -264,6 +330,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
         """
         if collection is None:
             raise ValueError("Must provide 'collection' named parameter.")
-        vecstore = cls(collection, embedding, **kwargs)
-        vecstore.add_texts(texts, metadatas=metadatas)
-        return vecstore
+        vectorstore = cls(collection, embedding, **kwargs)
+        vectorstore.add_texts(texts, metadatas=metadatas)
+        return vectorstore


@@ -119,3 +119,18 @@ class TestMongoDBAtlasVectorSearch:
             "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
         )
         assert output == []
+
+    def test_mmr(self, embedding_openai: Embeddings) -> None:
+        texts = ["foo", "foo", "fou", "foy"]
+        vectorstore = MongoDBAtlasVectorSearch.from_texts(
+            texts,
+            embedding_openai,
+            collection=collection,
+            index_name=INDEX_NAME,
+        )
+        sleep(1)  # waits for mongot to update Lucene's index
+        query = "foo"
+        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
+        assert len(output) == len(texts)
+        assert output[0].page_content == "foo"
+        assert output[1].page_content != "foo"