From a47f69a120cdc337e4ab1dd5e732fe1a714c33ad Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Wed, 17 Jul 2024 13:26:20 -0700 Subject: [PATCH] partners/mongodb : Significant MongoDBVectorSearch ID enhancements (#23535) ## Description This pull request improves the treatment of document IDs in `MongoDBAtlasVectorSearch`. The signatures of add_documents, add_texts, delete, and from_texts now include an `ids: Optional[List[str]]` keyword argument, giving the user greater control. Note that, as before, IDs may also be inferred from `Document.metadata['_id']` if present, but this is no longer required. IDs can also optionally be returned from searches. This PR closes the following JIRA issues: * [PYTHON-4446](https://jira.mongodb.org/browse/PYTHON-4446) MongoDBVectorSearch delete / add_texts function rework * [PYTHON-4435](https://jira.mongodb.org/browse/PYTHON-4435) Add support for "Indexing" * [PYTHON-4534](https://jira.mongodb.org/browse/PYTHON-4534) Ensure datetimes are json-serializable --------- Co-authored-by: Erick Friis --- .../mongodb/langchain_mongodb/utils.py | 59 +++- .../mongodb/langchain_mongodb/vectorstores.py | 175 +++++++++--- .../integration_tests/test_vectorstores.py | 255 ++++++++++++++---- .../tests/unit_tests/test_vectorstores.py | 3 +- 4 files changed, 403 insertions(+), 89 deletions(-) diff --git a/libs/partners/mongodb/langchain_mongodb/utils.py b/libs/partners/mongodb/langchain_mongodb/utils.py index cea4b8c0446..09693bf1483 100644 --- a/libs/partners/mongodb/langchain_mongodb/utils.py +++ b/libs/partners/mongodb/langchain_mongodb/utils.py @@ -8,10 +8,15 @@ are duplicated in this utility respectively from modules: - "libs/community/langchain_community/utils/math.py" """ +from __future__ import annotations + import logging -from typing import List, Union +from datetime import date, datetime +from typing import Any, Dict, List, Union import numpy as np +from bson import ObjectId +from bson.errors import InvalidId logger = 
logging.getLogger(__name__) @@ -88,3 +93,55 @@ def maximal_marginal_relevance( idxs.append(idx_to_add) selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) return idxs + + +def str_to_oid(str_repr: str) -> ObjectId | str: + """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId. + + To be consistent with ObjectId, input must be a 24 character hex string. + If it is not, MongoDB will happily use the string in the main _id index. + Importantly, the str representation that comes out of MongoDB will have this form. + + Args: + str_repr: id as string. + + Returns: + ObjectID + """ + try: + return ObjectId(str_repr) + except InvalidId: + logger.debug( + "ObjectIds must be 12-character byte or 24-character hex strings. " + "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'" + ) + return str_repr + + +def oid_to_str(oid: ObjectId) -> str: + """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility. + + Instructive helper to show where data is coming out of MongoDB. + + Args: + oid: bson.ObjectId + + Returns: + 24 character hex string. 
+ """ + return str(oid) + + +def make_serializable( + obj: Dict[str, Any], +) -> None: + """Recursively cast values in a dict to a form able to json.dump""" + for k, v in obj.items(): + if isinstance(v, dict): + make_serializable(v) + elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)): + obj[k] = [oid_to_str(item) for item in v] + elif isinstance(v, ObjectId): + obj[k] = oid_to_str(v) + elif isinstance(v, (datetime, date)): + obj[k] = v.isoformat() diff --git a/libs/partners/mongodb/langchain_mongodb/vectorstores.py b/libs/partners/mongodb/langchain_mongodb/vectorstores.py index 4eafd5abf10..c85bce58b85 100644 --- a/libs/partners/mongodb/langchain_mongodb/vectorstores.py +++ b/libs/partners/mongodb/langchain_mongodb/vectorstores.py @@ -16,7 +16,6 @@ from typing import ( ) import numpy as np -from bson import ObjectId, json_util from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor @@ -30,7 +29,12 @@ from langchain_mongodb.index import ( create_vector_search_index, update_vector_search_index, ) -from langchain_mongodb.utils import maximal_marginal_relevance +from langchain_mongodb.utils import ( + make_serializable, + maximal_marginal_relevance, + oid_to_str, + str_to_oid, +) MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any]) VST = TypeVar("VST", bound=VectorStore) @@ -143,51 +147,153 @@ class MongoDBAtlasVectorSearch(VectorStore): self, texts: Iterable[str], metadatas: Optional[List[Dict[str, Any]]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. + """Add texts, create embeddings, and add to the Collection and index. + + Important notes on ids: + - If _id or id is a key in the metadatas dicts, one must + pop them and provide as separate list. + - They must be unique. 
+ - If they are not provided, the VectorStore will create unique ones, + stored as bson.ObjectIds internally, and strings in Langchain. + These will appear in Document.metadata with key, '_id'. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique ids that will be used as index in VectorStore. + See note on ids. Returns: - List of ids from adding the texts into the vectorstore. + List of ids added to the vectorstore. """ - batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE) - _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts) + + # Check to see if metadata includes ids + if metadatas is not None and ( + metadatas[0].get("_id") or metadatas[0].get("id") + ): + logger.warning( + "_id or id key found in metadata. " + "Please pop from each dict and input as separate list." + "Retrieving methods will include the same id as '_id' in metadata." + ) + texts_batch = texts + _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts) metadatas_batch = _metadatas + result_ids = [] + batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE) if batch_size: texts_batch = [] metadatas_batch = [] size = 0 - for i, (text, metadata) in enumerate(zip(texts, _metadatas)): + i = 0 + for j, (text, metadata) in enumerate(zip(texts, _metadatas)): size += len(text) + len(metadata) texts_batch.append(text) metadatas_batch.append(metadata) - if (i + 1) % batch_size == 0 or size >= 47_000_000: - result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) + if (j + 1) % batch_size == 0 or size >= 47_000_000: + if ids: + batch_res = self.bulk_embed_and_insert_texts( + texts_batch, metadatas_batch, ids[i : j + 1] + ) + else: + batch_res = self.bulk_embed_and_insert_texts( + texts_batch, metadatas_batch + ) + result_ids.extend(batch_res) texts_batch = [] metadatas_batch = [] size = 0 + i = j + 1 if texts_batch: - 
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) # type: ignore - return [str(id) for id in result_ids] + if ids: + batch_res = self.bulk_embed_and_insert_texts( + texts_batch, metadatas_batch, ids[i : j + 1] + ) + else: + batch_res = self.bulk_embed_and_insert_texts( + texts_batch, metadatas_batch + ) + result_ids.extend(batch_res) + return result_ids - def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List: + def bulk_embed_and_insert_texts( + self, + texts: Union[List[str], Iterable[str]], + metadatas: Union[List[dict], Generator[dict, Any, Any]], + ids: Optional[List[str]] = None, + ) -> List[str]: + """Bulk insert single batch of texts, embeddings, and optionally ids. + + See add_texts for additional details. + """ if not texts: return [] - # Embed and create the documents - embeddings = self._embedding.embed_documents(texts) - to_insert = [ - {self._text_key: t, self._embedding_key: embedding, **m} - for t, m, embedding in zip(texts, metadatas, embeddings) - ] + # Compute embedding vectors + embeddings = self._embedding.embed_documents(texts) # type: ignore + if ids: + to_insert = [ + { + "_id": str_to_oid(i), + self._text_key: t, + self._embedding_key: embedding, + **m, + } + for i, t, m, embedding in zip(ids, texts, metadatas, embeddings) + ] + else: + to_insert = [ + {self._text_key: t, self._embedding_key: embedding, **m} + for t, m, embedding in zip(texts, metadatas, embeddings) + ] # insert the documents in MongoDB Atlas insert_result = self._collection.insert_many(to_insert) # type: ignore - return insert_result.inserted_ids + return [oid_to_str(_id) for _id in insert_result.inserted_ids] + + def add_documents( + self, + documents: List[Document], + ids: Optional[List[str]] = None, + batch_size: int = DEFAULT_INSERT_BATCH_SIZE, + **kwargs: Any, + ) -> List[str]: + """Add documents to the vectorstore. + + Args: + documents: Documents to add to the vectorstore. 
+ ids: Optional list of unique ids that will be used as index in VectorStore. + See note on ids in add_texts. + batch_size: Number of documents to insert at a time. + Tuning this may help with performance and sidestep MongoDB limits. + + Returns: + List of IDs of the added texts. + """ + n_docs = len(documents) + if ids: + assert len(ids) == n_docs, "Number of ids must equal number of documents." + result_ids = [] + start = 0 + for end in range(batch_size, n_docs + batch_size, batch_size): + texts, metadatas = zip( + *[(doc.page_content, doc.metadata) for doc in documents[start:end]] + ) + if ids: + result_ids.extend( + self.bulk_embed_and_insert_texts( + texts=texts, metadatas=metadatas, ids=ids[start:end] + ) + ) + else: + result_ids.extend( + self.bulk_embed_and_insert_texts(texts=texts, metadatas=metadatas) + ) + start = end + return result_ids def _similarity_search_with_score( self, @@ -196,8 +302,10 @@ class MongoDBAtlasVectorSearch(VectorStore): pre_filter: Optional[Dict] = None, post_filter_pipeline: Optional[List[Dict]] = None, include_embedding: bool = False, + include_ids: bool = False, **kwargs: Any, ) -> List[Tuple[Document, float]]: + """Core implementation.""" params = { "queryVector": embedding, "path": self._embedding_key, @@ -223,22 +331,10 @@ class MongoDBAtlasVectorSearch(VectorStore): cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type] docs = [] - def _make_serializable(obj: Dict[str, Any]) -> None: - for k, v in obj.items(): - if isinstance(v, dict): - _make_serializable(v) - elif isinstance(v, list) and v and isinstance(v[0], ObjectId): - obj[k] = [json_util.default(item) for item in v] - elif isinstance(v, ObjectId): - obj[k] = json_util.default(v) - for res in cursor: text = res.pop(self._text_key) score = res.pop("score") - # Make every ObjectId found JSON-Serializable - # following format used in bson.json_util.loads - # e.g. 
loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501 - _make_serializable(res) + make_serializable(res) docs.append((Document(page_content=text, metadata=res), score)) return docs @@ -363,6 +459,7 @@ class MongoDBAtlasVectorSearch(VectorStore): embedding: Embeddings, metadatas: Optional[List[Dict]] = None, collection: Optional[Collection[MongoDBDocumentType]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, ) -> MongoDBAtlasVectorSearch: """Construct a `MongoDB Atlas Vector Search` vector store from raw documents. @@ -394,25 +491,25 @@ class MongoDBAtlasVectorSearch(VectorStore): if collection is None: raise ValueError("Must provide 'collection' named parameter.") vectorstore = cls(collection, embedding, **kwargs) - vectorstore.add_texts(texts, metadatas=metadatas) + vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs) return vectorstore def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: - """Delete by ObjectId or other criteria. + """Delete documents from VectorStore by ids. Args: ids: List of ids to delete. - **kwargs: Other keyword arguments that subclasses might use. + **kwargs: Other keyword arguments passed to Collection.delete_many() Returns: Optional[bool]: True if deletion is successful, False otherwise, None if not implemented. 
""" - search_params: dict[str, Any] = {} + filter = {} if ids: - search_params["_id"] = {"$in": [ObjectId(id) for id in ids]} - - return self._collection.delete_many({**search_params, **kwargs}).acknowledged + oids = [str_to_oid(i) for i in ids] + filter = {"_id": {"$in": oids}} + return self._collection.delete_many(filter=filter, **kwargs).acknowledged async def adelete( self, ids: Optional[List[str]] = None, **kwargs: Any diff --git a/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py b/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py index e5ac536e012..c5fb6d26086 100644 --- a/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/mongodb/tests/integration_tests/test_vectorstores.py @@ -4,9 +4,10 @@ from __future__ import annotations import os from time import monotonic, sleep -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, Iterable, List, Optional, Union import pytest # type: ignore[import-not-found] +from bson import ObjectId from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from pymongo import MongoClient @@ -15,7 +16,9 @@ from pymongo.errors import OperationFailure from langchain_mongodb import MongoDBAtlasVectorSearch from langchain_mongodb.index import drop_vector_search_index -from tests.utils import ConsistentFakeEmbeddings +from langchain_mongodb.utils import oid_to_str + +from ..utils import ConsistentFakeEmbeddings INDEX_NAME = "langchain-test-index-vectorstores" INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test" @@ -25,20 +28,25 @@ DB_NAME, COLLECTION_NAME = NAMESPACE.split(".") INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index" INDEX_DB_NAME = "langchain_test_index_db" DIMENSIONS = 1536 -TIMEOUT = 10.0 +TIMEOUT = 120.0 INTERVAL = 0.5 class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch): - def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) 
-> List: + def bulk_embed_and_insert_texts( + self, + texts: Union[List[str], Iterable[str]], + metadatas: Union[List[dict], Generator[dict, Any, Any]], + ids: Optional[List[str]] = None, + ) -> List: """Patched insert_texts that waits for data to be indexed before returning""" - ids = super()._insert_texts(texts, metadatas) + ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids) start = monotonic() - while len(ids) != self.similarity_search("sandwich") and ( + while len(ids_inserted) != len(self.similarity_search("sandwich")) and ( monotonic() - start <= TIMEOUT ): sleep(INTERVAL) - return ids + return ids_inserted def create_vector_search_index( self, @@ -87,6 +95,16 @@ def collection() -> Collection: return get_collection() +@pytest.fixture +def texts() -> List[str]: + return [ + "Dogs are tough.", + "Cats have fluff.", + "What is a sandwich?", + "That fence is purple.", + ] + + @pytest.fixture() def index_collection() -> Collection: return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME) @@ -118,12 +136,10 @@ class TestMongoDBAtlasVectorSearch: ) @pytest.fixture - def embedding_openai(self) -> Embeddings: + def embeddings(self) -> Embeddings: return ConsistentFakeEmbeddings(DIMENSIONS) - def test_from_documents( - self, embedding_openai: Embeddings, collection: Any - ) -> None: + def test_from_documents(self, embeddings: Embeddings, collection: Any) -> None: """Test end to end construction and search.""" documents = [ Document(page_content="Dogs are tough.", metadata={"a": 1}), @@ -133,7 +149,7 @@ class TestMongoDBAtlasVectorSearch: ] vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( documents, - embedding_openai, + embeddings, collection=collection, index_name=INDEX_NAME, ) @@ -143,7 +159,7 @@ class TestMongoDBAtlasVectorSearch: assert any([key.page_content == output[0].page_content for key in documents]) def test_from_documents_no_embedding_return( - self, embedding_openai: Embeddings, collection: Any + self, embeddings: 
Embeddings, collection: Any ) -> None: """Test end to end construction and search.""" documents = [ @@ -154,7 +170,7 @@ class TestMongoDBAtlasVectorSearch: ] vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( documents, - embedding_openai, + embeddings, collection=collection, index_name=INDEX_NAME, ) @@ -166,7 +182,7 @@ class TestMongoDBAtlasVectorSearch: assert any([key.page_content == output[0].page_content for key in documents]) def test_from_documents_embedding_return( - self, embedding_openai: Embeddings, collection: Any + self, embeddings: Embeddings, collection: Any ) -> None: """Test end to end construction and search.""" documents = [ @@ -177,7 +193,7 @@ class TestMongoDBAtlasVectorSearch: ] vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents( documents, - embedding_openai, + embeddings, collection=collection, index_name=INDEX_NAME, ) @@ -188,16 +204,12 @@ class TestMongoDBAtlasVectorSearch: # Check for the presence of the metadata key assert any([key.page_content == output[0].page_content for key in documents]) - def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "That fence is purple.", - ] + def test_from_texts( + self, embeddings: Embeddings, collection: Collection, texts: List[str] + ) -> None: vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( texts, - embedding_openai, + embeddings, collection=collection, index_name=INDEX_NAME, ) @@ -205,19 +217,16 @@ class TestMongoDBAtlasVectorSearch: assert len(output) == 1 def test_from_texts_with_metadatas( - self, embedding_openai: Embeddings, collection: Any + self, + embeddings: Embeddings, + collection: Collection, + texts: List[str], ) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "The fence is purple.", - ] metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] metakeys = ["a", "b", "c", "d", "e"] vectorstore = 
PatchedMongoDBAtlasVectorSearch.from_texts( texts, - embedding_openai, + embeddings, metadatas=metadatas, collection=collection, index_name=INDEX_NAME, @@ -228,18 +237,12 @@ class TestMongoDBAtlasVectorSearch: assert any([key in output[0].metadata for key in metakeys]) def test_from_texts_with_metadatas_and_pre_filter( - self, embedding_openai: Embeddings, collection: Any + self, embeddings: Embeddings, collection: Any, texts: List[str] ) -> None: - texts = [ - "Dogs are tough.", - "Cats have fluff.", - "What is a sandwich?", - "The fence is purple.", - ] metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( texts, - embedding_openai, + embeddings, metadatas=metadatas, collection=collection, index_name=INDEX_NAME, @@ -249,11 +252,11 @@ class TestMongoDBAtlasVectorSearch: ) assert output == [] - def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None: + def test_mmr(self, embeddings: Embeddings, collection: Any) -> None: texts = ["foo", "foo", "fou", "foy"] vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( texts, - embedding_openai, + embeddings, collection=collection, index_name=INDEX_NAME, ) @@ -263,19 +266,175 @@ class TestMongoDBAtlasVectorSearch: assert output[0].page_content == "foo" assert output[1].page_content != "foo" + def test_delete( + self, embeddings: Embeddings, collection: Any, texts: List[str] + ) -> None: + vectorstore = MongoDBAtlasVectorSearch( # PatchedMongoDBAtlasVectorSearch( + collection=collection, + embedding=embeddings, + index_name=INDEX_NAME, + ) + clxn: Collection = vectorstore._collection + assert clxn.count_documents({}) == 0 + ids = vectorstore.add_texts(texts) + assert clxn.count_documents({}) == len(texts) + + deleted = vectorstore.delete(ids[-2:]) + assert deleted + assert clxn.count_documents({}) == len(texts) - 2 + + new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"]) + assert set(new_ids).intersection(set(ids)) == 
set() # new ids will be unique. + assert isinstance(new_ids, list) + assert all(isinstance(i, str) for i in new_ids) + assert len(new_ids) == 2 + assert clxn.count_documents({}) == 4 + + def test_add_texts( + self, + embeddings: Embeddings, + collection: Collection, + texts: List[str], + ) -> None: + """Tests API of add_texts, focussing on id treatment""" + metadatas: List[Dict[str, Any]] = [ + {"a": 1}, + {"b": 1}, + {"c": 1}, + {"d": 1, "e": 2}, + ] + + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection, embedding=embeddings, index_name=INDEX_NAME + ) + + # Case 1. Add texts without ids + provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + all_docs = list(vectorstore._collection.find({})) + assert all("_id" in doc for doc in all_docs) + docids = set(doc["_id"] for doc in all_docs) + assert all(isinstance(_id, ObjectId) for _id in docids) # + assert set(provided_ids) == set(oid_to_str(oid) for oid in docids) + + # Case 2: Test Document.metadata looks right. i.e. 
contains _id + search_res = vectorstore.similarity_search_with_score("sandwich", k=1) + doc, score = search_res[0] + assert "_id" in doc.metadata + + # Case 3: Add new ids that are 24-char hex strings + hex_ids = [oid_to_str(ObjectId()) for _ in range(2)] + hex_texts = ["Text for hex_id"] * len(hex_ids) + out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids) + assert set(out_ids) == set(hex_ids) + assert collection.count_documents({}) == len(texts) + len(hex_texts) + assert all( + isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({}) + ) + + # Case 4: Add new ids that cannot be cast to ObjectId + # - We can still index and search on them + str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."] + str_texts = str_ids # No reason for them to differ + out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids) + assert set(out_ids) == set(str_ids) + assert collection.count_documents({}) == 8 + res = vectorstore.similarity_search("sandwich", k=8) + assert any(str_ids[0] in doc.metadata["_id"] for doc in res) + + # Case 5: Test adding in multiple batches + batch_size = 2 + batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)] + batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)] + out_ids = vectorstore.add_texts( + texts=batch_texts, ids=batch_ids, batch_size=batch_size + ) + assert set(out_ids) == set(batch_ids) + assert collection.count_documents({}) == 12 + + # Case 6: _ids in metadata + collection.delete_many({}) + # 6a. Unique _id in metadata, but ids=None + # Will be added as if ids kwarg provided + i = 0 + n = len(texts) + assert len(metadatas) == n + _ids = [str(i) for i in range(n)] + for md in metadatas: + md["_id"] = _ids[i] + i += 1 + returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + assert returned_ids == ["0", "1", "2", "3"] + assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids) + + # 6b. 
Unique "id", not "_id", but ids=None + # New ids will be assigned + i = 1 + for md in metadatas: + md.pop("_id") + md["id"] = f"{1}" + i += 1 + returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas) + assert len(set(returned_ids).intersection(set(_ids))) == 0 + + def test_add_documents( + self, + embeddings: Embeddings, + collection: Collection, + index_name: str = INDEX_NAME, + ) -> None: + """Tests add_documents.""" + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection, embedding=embeddings, index_name=INDEX_NAME + ) + + # Case 1: No ids + n_docs = 10 + batch_size = 3 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) + for i in range(n_docs) + ] + result_ids = vectorstore.add_documents(docs, batch_size=batch_size) + assert len(result_ids) == n_docs + assert collection.count_documents({}) == n_docs + + # Case 2: ids + collection.delete_many({}) + n_docs = 10 + batch_size = 3 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) + for i in range(n_docs) + ] + ids = [str(i) for i in range(n_docs)] + result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) + assert len(result_ids) == n_docs + assert set(ids) == set(collection.distinct("_id")) + + # Case 3: Single batch + collection.delete_many({}) + n_docs = 3 + batch_size = 10 + docs = [ + Document(page_content=f"document {i}", metadata={"i": i}) + for i in range(n_docs) + ] + ids = [str(i) for i in range(n_docs)] + result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size) + assert len(result_ids) == n_docs + assert set(ids) == set(collection.distinct("_id")) + def test_index_creation( - self, embedding_openai: Embeddings, index_collection: Any + self, embeddings: Embeddings, index_collection: Any ) -> None: vectorstore = PatchedMongoDBAtlasVectorSearch( - index_collection, embedding_openai, index_name=INDEX_CREATION_NAME + index_collection, embeddings, index_name=INDEX_CREATION_NAME ) 
vectorstore.create_vector_search_index(dimensions=1536) - def test_index_update( - self, embedding_openai: Embeddings, index_collection: Any - ) -> None: + def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None: vectorstore = PatchedMongoDBAtlasVectorSearch( - index_collection, embedding_openai, index_name=INDEX_CREATION_NAME + index_collection, embeddings, index_name=INDEX_CREATION_NAME ) vectorstore.create_vector_search_index(dimensions=1536) vectorstore.create_vector_search_index(dimensions=1536, update=True) diff --git a/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py b/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py index 6e256562759..0ac6898f5ba 100644 --- a/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py +++ b/libs/partners/mongodb/tests/unit_tests/test_vectorstores.py @@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings from pymongo.collection import Collection from langchain_mongodb import MongoDBAtlasVectorSearch +from langchain_mongodb.utils import str_to_oid from tests.utils import ConsistentFakeEmbeddings, MockCollection INDEX_NAME = "langchain-test-index" @@ -81,7 +82,7 @@ class TestMongoDBAtlasVectorSearch: assert loads(dumps(output[0].page_content)) == output[0].page_content assert loads(dumps(output[0].metadata)) == output[0].metadata json_metadata = dumps(output[0].metadata) # normal json.dumps - assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId) + isinstance(str_to_oid(json_util.loads(json_metadata)["_id"]), ObjectId) def test_from_documents( self, embedding_openai: Embeddings, collection: MockCollection