mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
adding max_marginal_relevance_search method to MongoDBAtlasVectorSearch (#7310)
Adding a maximal_marginal_relevance method to the MongoDBAtlasVectorSearch vectorstore enhances the user experience by providing more diverse search results Issue: #7304
This commit is contained in:
parent
04cddfba0d
commit
d2cf0d16b3
@ -1,10 +1,14 @@
|
|||||||
{
|
{
|
||||||
"cells":[
|
"cells":[
|
||||||
{
|
{
|
||||||
"attachments": {},
|
"attachments":{
|
||||||
|
|
||||||
|
},
|
||||||
"cell_type":"markdown",
|
"cell_type":"markdown",
|
||||||
"id":"683953b3",
|
"id":"683953b3",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
|
|
||||||
|
},
|
||||||
"source":[
|
"source":[
|
||||||
"# MongoDB Atlas\n",
|
"# MongoDB Atlas\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -23,9 +27,13 @@
|
|||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"b4c41cad-08ef-4f72-a545-2151e4598efe",
|
"id":"b4c41cad-08ef-4f72-a545-2151e4598efe",
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"tags": []
|
"tags":[
|
||||||
|
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"!pip install pymongo"
|
"!pip install pymongo"
|
||||||
]
|
]
|
||||||
@ -34,21 +42,28 @@
|
|||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
|
"id":"c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
"outputs": [],
|
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import getpass\n",
|
"import getpass\n",
|
||||||
"\n",
|
"\n",
|
||||||
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
|
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n"
|
||||||
"MONGODB_ATLAS_CLUSTER_URI = os.environ[\"MONGODB_ATLAS_CLUSTER_URI\"]"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
"attachments":{
|
||||||
|
|
||||||
|
},
|
||||||
"cell_type":"markdown",
|
"cell_type":"markdown",
|
||||||
"id":"457ace44-1d95-4001-9dd5-78811ab208ad",
|
"id":"457ace44-1d95-4001-9dd5-78811ab208ad",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
|
|
||||||
|
},
|
||||||
"source":[
|
"source":[
|
||||||
"We want to use `OpenAIEmbeddings` so we need to set up our OpenAI API Key. "
|
"We want to use `OpenAIEmbeddings` so we need to set up our OpenAI API Key. "
|
||||||
]
|
]
|
||||||
@ -57,18 +72,25 @@
|
|||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"2d8f240d",
|
"id":"2d8f240d",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
"outputs": [],
|
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
|
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n"
|
||||||
"OPENAI_API_KEY = os.environ[\"OPENAI_API_KEY\"]"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
"attachments":{
|
||||||
|
|
||||||
|
},
|
||||||
"cell_type":"markdown",
|
"cell_type":"markdown",
|
||||||
"id":"1f3ecc42",
|
"id":"1f3ecc42",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
|
|
||||||
|
},
|
||||||
"source":[
|
"source":[
|
||||||
"Now, let's create a vector search index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Vector Search index.\n",
|
"Now, let's create a vector search index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Vector Search index.\n",
|
||||||
"You can name the index `langchain_demo` and create the index on the namespace `lanchain_db.langchain_col`. Finally, write the following definition in the JSON editor on MongoDB Atlas:\n",
|
"You can name the index `langchain_demo` and create the index on the namespace `lanchain_db.langchain_col`. Finally, write the following definition in the JSON editor on MongoDB Atlas:\n",
|
||||||
@ -94,23 +116,17 @@
|
|||||||
"execution_count":2,
|
"execution_count":2,
|
||||||
"id":"aac9563e",
|
"id":"aac9563e",
|
||||||
"metadata":{
|
"metadata":{
|
||||||
"tags": []
|
"tags":[
|
||||||
|
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||||
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
||||||
"from langchain.document_loaders import TextLoader"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "a3c3999a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain.document_loaders import TextLoader\n",
|
"from langchain.document_loaders import TextLoader\n",
|
||||||
"\n",
|
"\n",
|
||||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||||
@ -125,8 +141,12 @@
|
|||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"6e104aee",
|
"id":"6e104aee",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
"outputs": [],
|
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"from pymongo import MongoClient\n",
|
"from pymongo import MongoClient\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -152,51 +172,48 @@
|
|||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"9c608226",
|
"id":"9c608226",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
"outputs": [],
|
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
"attachments":{
|
||||||
|
|
||||||
|
},
|
||||||
"cell_type":"markdown",
|
"cell_type":"markdown",
|
||||||
"id":"851a2ec9-9390-49a4-8412-3e132c9f789d",
|
"id":"851a2ec9-9390-49a4-8412-3e132c9f789d",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
|
|
||||||
|
},
|
||||||
"source":[
|
"source":[
|
||||||
"You can reuse the vector search index you created, make sure the `OPENAI_API_KEY` environment variable is set up, then execute another query."
|
"You can also instantiate the vector store directly and execute a query as follows:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type":"code",
|
"cell_type":"code",
|
||||||
"execution_count":null,
|
"execution_count":null,
|
||||||
"id":"6336fe79-3e73-48be-b20a-0ff1bb6a4399",
|
"id":"6336fe79-3e73-48be-b20a-0ff1bb6a4399",
|
||||||
"metadata": {},
|
"metadata":{
|
||||||
"outputs": [],
|
|
||||||
|
},
|
||||||
|
"outputs":[
|
||||||
|
|
||||||
|
],
|
||||||
"source":[
|
"source":[
|
||||||
"from pymongo import MongoClient\n",
|
|
||||||
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
|
||||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
|
||||||
"import os\n",
|
|
||||||
"\n",
|
|
||||||
"MONGODB_ATLAS_URI = os.environ[\"MONGODB_ATLAS_URI\"]\n",
|
|
||||||
"\n",
|
|
||||||
"# initialize MongoDB python client\n",
|
|
||||||
"client = MongoClient(MONGODB_ATLAS_URI)\n",
|
|
||||||
"\n",
|
|
||||||
"db_name = \"langchain_db\"\n",
|
|
||||||
"collection_name = \"langchain_col\"\n",
|
|
||||||
"collection = client[db_name][collection_name]\n",
|
|
||||||
"index_name = \"langchain_demo\"\n",
|
|
||||||
"\n",
|
|
||||||
"# initialize vector store\n",
|
"# initialize vector store\n",
|
||||||
"vectorStore = MongoDBAtlasVectorSearch(\n",
|
"vectorstore = MongoDBAtlasVectorSearch(\n",
|
||||||
" collection, OpenAIEmbeddings(), index_name=index_name\n",
|
" collection, OpenAIEmbeddings(), index_name=index_name\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# perform a similarity search between a query and the ingested documents\n",
|
"# perform a similarity search between a query and the ingested documents\n",
|
||||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
"docs = vectorStore.similarity_search(query)\n",
|
"docs = vectorstore.similarity_search(query)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
]
|
]
|
||||||
|
@ -14,9 +14,12 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo.collection import Collection
|
from pymongo.collection import Collection
|
||||||
@ -137,6 +140,39 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
insert_result = self._collection.insert_many(to_insert)
|
insert_result = self._collection.insert_many(to_insert)
|
||||||
return insert_result.inserted_ids
|
return insert_result.inserted_ids
|
||||||
|
|
||||||
|
def _similarity_search_with_score(
|
||||||
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
pre_filter: Optional[dict] = None,
|
||||||
|
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
knn_beta = {
|
||||||
|
"vector": embedding,
|
||||||
|
"path": self._embedding_key,
|
||||||
|
"k": k,
|
||||||
|
}
|
||||||
|
if pre_filter:
|
||||||
|
knn_beta["filter"] = pre_filter
|
||||||
|
pipeline = [
|
||||||
|
{
|
||||||
|
"$search": {
|
||||||
|
"index": self._index_name,
|
||||||
|
"knnBeta": knn_beta,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$set": {"score": {"$meta": "searchScore"}}},
|
||||||
|
]
|
||||||
|
if post_filter_pipeline is not None:
|
||||||
|
pipeline.extend(post_filter_pipeline)
|
||||||
|
cursor = self._collection.aggregate(pipeline)
|
||||||
|
docs = []
|
||||||
|
for res in cursor:
|
||||||
|
text = res.pop(self._text_key)
|
||||||
|
score = res.pop("score")
|
||||||
|
docs.append((Document(page_content=text, metadata=res), score))
|
||||||
|
return docs
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -165,30 +201,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query and score for each
|
List of Documents most similar to the query and score for each
|
||||||
"""
|
"""
|
||||||
knn_beta = {
|
embedding = self._embedding.embed_query(query)
|
||||||
"vector": self._embedding.embed_query(query),
|
docs = self._similarity_search_with_score(
|
||||||
"path": self._embedding_key,
|
embedding,
|
||||||
"k": k,
|
k=k,
|
||||||
}
|
pre_filter=pre_filter,
|
||||||
if pre_filter:
|
post_filter_pipeline=post_filter_pipeline,
|
||||||
knn_beta["filter"] = pre_filter
|
)
|
||||||
pipeline = [
|
|
||||||
{
|
|
||||||
"$search": {
|
|
||||||
"index": self._index_name,
|
|
||||||
"knnBeta": knn_beta,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{"$project": {"score": {"$meta": "searchScore"}, self._embedding_key: 0}},
|
|
||||||
]
|
|
||||||
if post_filter_pipeline is not None:
|
|
||||||
pipeline.extend(post_filter_pipeline)
|
|
||||||
cursor = self._collection.aggregate(pipeline)
|
|
||||||
docs = []
|
|
||||||
for res in cursor:
|
|
||||||
text = res.pop(self._text_key)
|
|
||||||
score = res.pop("score")
|
|
||||||
docs.append((Document(page_content=text, metadata=res), score))
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
@ -227,6 +246,53 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
)
|
)
|
||||||
return [doc for doc, _ in docs_and_scores]
|
return [doc for doc, _ in docs_and_scores]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
pre_filter: Optional[dict] = None,
|
||||||
|
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Optional Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Optional Number of Documents to fetch before passing to MMR
|
||||||
|
algorithm. Defaults to 20.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
pre_filter: Optional Dictionary of argument(s) to prefilter on document
|
||||||
|
fields.
|
||||||
|
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
|
||||||
|
following the knnBeta search.
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
query_embedding = self._embedding.embed_query(query)
|
||||||
|
docs = self._similarity_search_with_score(
|
||||||
|
query_embedding,
|
||||||
|
k=fetch_k,
|
||||||
|
pre_filter=pre_filter,
|
||||||
|
post_filter_pipeline=post_filter_pipeline,
|
||||||
|
)
|
||||||
|
mmr_doc_indexes = maximal_marginal_relevance(
|
||||||
|
np.array(query_embedding),
|
||||||
|
[doc.metadata[self._embedding_key] for doc, _ in docs],
|
||||||
|
k=k,
|
||||||
|
lambda_mult=lambda_mult,
|
||||||
|
)
|
||||||
|
mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
|
||||||
|
return mmr_docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls,
|
||||||
@ -252,7 +318,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||||
collection = mongo_client["<db_name>"]["<collection_name>"]
|
collection = mongo_client["<db_name>"]["<collection_name>"]
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||||
@ -264,6 +330,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
"""
|
"""
|
||||||
if collection is None:
|
if collection is None:
|
||||||
raise ValueError("Must provide 'collection' named parameter.")
|
raise ValueError("Must provide 'collection' named parameter.")
|
||||||
vecstore = cls(collection, embedding, **kwargs)
|
vectorstore = cls(collection, embedding, **kwargs)
|
||||||
vecstore.add_texts(texts, metadatas=metadatas)
|
vectorstore.add_texts(texts, metadatas=metadatas)
|
||||||
return vecstore
|
return vectorstore
|
||||||
|
@ -119,3 +119,18 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
|
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
|
||||||
)
|
)
|
||||||
assert output == []
|
assert output == []
|
||||||
|
|
||||||
|
def test_mmr(self, embedding_openai: Embeddings) -> None:
|
||||||
|
texts = ["foo", "foo", "fou", "foy"]
|
||||||
|
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||||
|
texts,
|
||||||
|
embedding_openai,
|
||||||
|
collection=collection,
|
||||||
|
index_name=INDEX_NAME,
|
||||||
|
)
|
||||||
|
sleep(1) # waits for mongot to update Lucene's index
|
||||||
|
query = "foo"
|
||||||
|
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
|
||||||
|
assert len(output) == len(texts)
|
||||||
|
assert output[0].page_content == "foo"
|
||||||
|
assert output[1].page_content != "foo"
|
||||||
|
Loading…
Reference in New Issue
Block a user