mongodb[patch]: Remove embedding retrieval from mongodb payload (#19035)

## Description
Returning the embedding is not necessary in vector search unless it is
explicitly needed, for example as a debugging step. This change makes the
server return the embedding field _only_ when it is explicitly requested
via `include_embedding=True`, which `max_marginal_relevance_search` does
internally because it needs the vectors to compute relevance.


- [x] **Add tests and docs**: If you're adding a new integration, please
  include tests and docs.
  * Added `test_from_documents_no_embedding_return`


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jib
2024-03-18 15:43:50 -04:00
committed by GitHub
parent 366ba77459
commit 866d6408af
2 changed files with 64 additions and 11 deletions

View File

@@ -183,6 +183,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
include_embedding: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
params = {
"queryVector": embedding,
@@ -199,6 +201,11 @@ class MongoDBAtlasVectorSearch(VectorStore):
query,
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Exclude the embedding key from the return payload
if not include_embedding:
pipeline.append({"$project": {self._embedding_key: 0}})
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
@@ -215,6 +222,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return MongoDB documents most similar to the given query and their scores.
@@ -238,6 +246,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
return docs
@@ -271,6 +280,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
if additional and "similarity_score" in additional:
@@ -310,20 +320,15 @@ class MongoDBAtlasVectorSearch(VectorStore):
List of documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
query_embedding,
k=fetch_k,
return self.max_marginal_relevance_search_by_vector(
embedding=query_embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(query_embedding),
[doc.metadata[self._embedding_key] for doc, _ in docs],
k=k,
lambda_mult=lambda_mult,
)
mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
return mmr_docs
@classmethod
def from_texts(
@@ -433,6 +438,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=fetch_k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
include_embedding=kwargs.pop("include_embedding", True),
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(embedding),