mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 14:23:58 +00:00
adding max_marginal_relevance_search method to MongoDBAtlasVectorSearch (#7310)
Adding a `max_marginal_relevance_search` method to the MongoDBAtlasVectorSearch vectorstore enhances the user experience by providing more diverse search results. Issue: #7304
This commit is contained in:
parent
04cddfba0d
commit
d2cf0d16b3
@ -1,10 +1,14 @@
|
||||
{
|
||||
"cells":[
|
||||
{
|
||||
"attachments": {},
|
||||
"attachments":{
|
||||
|
||||
},
|
||||
"cell_type":"markdown",
|
||||
"id":"683953b3",
|
||||
"metadata": {},
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"source":[
|
||||
"# MongoDB Atlas\n",
|
||||
"\n",
|
||||
@ -23,9 +27,13 @@
|
||||
"execution_count":null,
|
||||
"id":"b4c41cad-08ef-4f72-a545-2151e4598efe",
|
||||
"metadata":{
|
||||
"tags": []
|
||||
"tags":[
|
||||
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"!pip install pymongo"
|
||||
]
|
||||
@ -34,21 +42,28 @@
|
||||
"cell_type":"code",
|
||||
"execution_count":null,
|
||||
"id":"c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"\n",
|
||||
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
|
||||
"MONGODB_ATLAS_CLUSTER_URI = os.environ[\"MONGODB_ATLAS_CLUSTER_URI\"]"
|
||||
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"attachments":{
|
||||
|
||||
},
|
||||
"cell_type":"markdown",
|
||||
"id":"457ace44-1d95-4001-9dd5-78811ab208ad",
|
||||
"metadata": {},
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"source":[
|
||||
"We want to use `OpenAIEmbeddings` so we need to set up our OpenAI API Key. "
|
||||
]
|
||||
@ -57,18 +72,25 @@
|
||||
"cell_type":"code",
|
||||
"execution_count":null,
|
||||
"id":"2d8f240d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
|
||||
"OPENAI_API_KEY = os.environ[\"OPENAI_API_KEY\"]"
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"attachments":{
|
||||
|
||||
},
|
||||
"cell_type":"markdown",
|
||||
"id":"1f3ecc42",
|
||||
"metadata": {},
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"source":[
|
||||
"Now, let's create a vector search index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Vector Search index.\n",
|
||||
"You can name the index `langchain_demo` and create the index on the namespace `langchain_db.langchain_col`. Finally, write the following definition in the JSON editor on MongoDB Atlas:\n",
|
||||
@ -94,23 +116,17 @@
|
||||
"execution_count":2,
|
||||
"id":"aac9563e",
|
||||
"metadata":{
|
||||
"tags": []
|
||||
"tags":[
|
||||
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
||||
"from langchain.document_loaders import TextLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
|
||||
@ -125,8 +141,12 @@
|
||||
"cell_type":"code",
|
||||
"execution_count":null,
|
||||
"id":"6e104aee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"from pymongo import MongoClient\n",
|
||||
"\n",
|
||||
@ -152,51 +172,48 @@
|
||||
"cell_type":"code",
|
||||
"execution_count":null,
|
||||
"id":"9c608226",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"attachments":{
|
||||
|
||||
},
|
||||
"cell_type":"markdown",
|
||||
"id":"851a2ec9-9390-49a4-8412-3e132c9f789d",
|
||||
"metadata": {},
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"source":[
|
||||
"You can reuse the vector search index you created, make sure the `OPENAI_API_KEY` environment variable is set up, then execute another query."
|
||||
"You can also instantiate the vector store directly and execute a query as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type":"code",
|
||||
"execution_count":null,
|
||||
"id":"6336fe79-3e73-48be-b20a-0ff1bb6a4399",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata":{
|
||||
|
||||
},
|
||||
"outputs":[
|
||||
|
||||
],
|
||||
"source":[
|
||||
"from pymongo import MongoClient\n",
|
||||
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"MONGODB_ATLAS_URI = os.environ[\"MONGODB_ATLAS_URI\"]\n",
|
||||
"\n",
|
||||
"# initialize MongoDB python client\n",
|
||||
"client = MongoClient(MONGODB_ATLAS_URI)\n",
|
||||
"\n",
|
||||
"db_name = \"langchain_db\"\n",
|
||||
"collection_name = \"langchain_col\"\n",
|
||||
"collection = client[db_name][collection_name]\n",
|
||||
"index_name = \"langchain_demo\"\n",
|
||||
"\n",
|
||||
"# initialize vector store\n",
|
||||
"vectorStore = MongoDBAtlasVectorSearch(\n",
|
||||
"vectorstore = MongoDBAtlasVectorSearch(\n",
|
||||
" collection, OpenAIEmbeddings(), index_name=index_name\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# perform a similarity search between a query and the ingested documents\n",
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = vectorStore.similarity_search(query)\n",
|
||||
"docs = vectorstore.similarity_search(query)\n",
|
||||
"\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
|
@ -14,9 +14,12 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.collection import Collection
|
||||
@ -137,6 +140,39 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
insert_result = self._collection.insert_many(to_insert)
|
||||
return insert_result.inserted_ids
|
||||
|
||||
def _similarity_search_with_score(
    self,
    embedding: List[float],
    k: int = 4,
    pre_filter: Optional[dict] = None,
    post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
    """Run a knnBeta ``$search`` for an already-computed query embedding.

    Args:
        embedding: Query vector sent as the knnBeta ``vector``.
        k: Value sent as the knnBeta ``k``. Defaults to 4.
        pre_filter: Optional dictionary; when truthy it is attached to the
            knnBeta clause under ``filter``.
        post_filter_pipeline: Optional aggregation stages appended after
            the search and score stages.

    Returns:
        List of ``(Document, score)`` tuples in the order the cursor yields
        them. Each Document's metadata is whatever remains of the result
        after the text field and ``score`` are popped; no stage here drops
        the embedding field, so it stays in the metadata unless a
        post-filter stage removes it.
    """
    search_params = {
        "vector": embedding,
        "path": self._embedding_key,
        "k": k,
    }
    if pre_filter:
        search_params["filter"] = pre_filter

    search_stage = {
        "$search": {
            "index": self._index_name,
            "knnBeta": search_params,
        }
    }
    # Expose the search score as a regular "score" field on every result.
    score_stage = {"$set": {"score": {"$meta": "searchScore"}}}
    stages = [search_stage, score_stage]
    if post_filter_pipeline is not None:
        stages.extend(post_filter_pipeline)

    scored_docs = []
    for hit in self._collection.aggregate(stages):
        # Pop text and score so that only the remaining fields become metadata.
        page_content = hit.pop(self._text_key)
        score = hit.pop("score")
        scored_docs.append(
            (Document(page_content=page_content, metadata=hit), score)
        )
    return scored_docs
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
@ -165,30 +201,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
Returns:
|
||||
List of Documents most similar to the query and score for each
|
||||
"""
|
||||
knn_beta = {
|
||||
"vector": self._embedding.embed_query(query),
|
||||
"path": self._embedding_key,
|
||||
"k": k,
|
||||
}
|
||||
if pre_filter:
|
||||
knn_beta["filter"] = pre_filter
|
||||
pipeline = [
|
||||
{
|
||||
"$search": {
|
||||
"index": self._index_name,
|
||||
"knnBeta": knn_beta,
|
||||
}
|
||||
},
|
||||
{"$project": {"score": {"$meta": "searchScore"}, self._embedding_key: 0}},
|
||||
]
|
||||
if post_filter_pipeline is not None:
|
||||
pipeline.extend(post_filter_pipeline)
|
||||
cursor = self._collection.aggregate(pipeline)
|
||||
docs = []
|
||||
for res in cursor:
|
||||
text = res.pop(self._text_key)
|
||||
score = res.pop("score")
|
||||
docs.append((Document(page_content=text, metadata=res), score))
|
||||
embedding = self._embedding.embed_query(query)
|
||||
docs = self._similarity_search_with_score(
|
||||
embedding,
|
||||
k=k,
|
||||
pre_filter=pre_filter,
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
)
|
||||
return docs
|
||||
|
||||
def similarity_search(
|
||||
@ -227,6 +246,53 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    pre_filter: Optional[dict] = None,
    post_filter_pipeline: Optional[List[Dict]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND
    diversity among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Optional Number of Documents to return. Defaults to 4.
        fetch_k: Optional Number of Documents to fetch before passing to
            the MMR algorithm. Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree of
            diversity among the results, with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        pre_filter: Optional Dictionary of argument(s) to prefilter on
            document fields.
        post_filter_pipeline: Optional Pipeline of MongoDB aggregation
            stages following the knnBeta search.
        **kwargs: Accepted but not used by this method.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    query_vector = self._embedding.embed_query(query)
    candidates = self._similarity_search_with_score(
        query_vector,
        k=fetch_k,
        pre_filter=pre_filter,
        post_filter_pipeline=post_filter_pipeline,
    )

    # The stored vector of each candidate is read back from its metadata.
    candidate_vectors = []
    for candidate, _score in candidates:
        candidate_vectors.append(candidate.metadata[self._embedding_key])

    selected = maximal_marginal_relevance(
        np.array(query_vector),
        candidate_vectors,
        k=k,
        lambda_mult=lambda_mult,
    )
    return [candidates[position][0] for position in selected]
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
@ -252,7 +318,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
from langchain.vectorstores import MongoDBAtlasVectorSearch
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
|
||||
collection = mongo_client["<db_name>"]["<collection_name>"]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = MongoDBAtlasVectorSearch.from_texts(
|
||||
@ -264,6 +330,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
||||
"""
|
||||
if collection is None:
|
||||
raise ValueError("Must provide 'collection' named parameter.")
|
||||
vecstore = cls(collection, embedding, **kwargs)
|
||||
vecstore.add_texts(texts, metadatas=metadatas)
|
||||
return vecstore
|
||||
vectorstore = cls(collection, embedding, **kwargs)
|
||||
vectorstore.add_texts(texts, metadatas=metadatas)
|
||||
return vectorstore
|
||||
|
@ -119,3 +119,18 @@ class TestMongoDBAtlasVectorSearch:
|
||||
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
|
||||
)
|
||||
assert output == []
|
||||
|
||||
def test_mmr(self, embedding_openai: Embeddings) -> None:
    """Check that an MMR search returns every text and diversifies the top hits."""
    corpus = ["foo", "foo", "fou", "foy"]
    store = MongoDBAtlasVectorSearch.from_texts(
        corpus,
        embedding_openai,
        collection=collection,
        index_name=INDEX_NAME,
    )
    sleep(1)  # waits for mongot to update Lucene's index
    results = store.max_marginal_relevance_search(
        "foo", k=10, lambda_mult=0.1
    )
    # k exceeds the corpus size, so every ingested text should come back.
    assert len(results) == len(corpus)
    first, second = results[0], results[1]
    assert first.page_content == "foo"
    # With a low lambda_mult the runner-up should not be the duplicate "foo".
    assert second.page_content != "foo"
|
||||
|
Loading…
Reference in New Issue
Block a user