partners/mongodb: Significant MongoDBVectorSearch ID enhancements (#23535)

## Description

This pull request improves the treatment of document IDs in
`MongoDBAtlasVectorSearch`.

The method signatures of `add_documents`, `add_texts`, `delete`, and
`from_texts` now include an `ids: Optional[List[str]]` keyword argument,
giving the user greater control over IDs. As before, IDs may also be
inferred from `Document.metadata['_id']` if present, but this is no longer
required. IDs can also optionally be returned from searches.
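
For illustration, a minimal sketch of the new keyword argument (assuming an
existing pymongo `collection`, an `embedding` implementing
`langchain_core.embeddings.Embeddings`, and a placeholder index name):

```python
from langchain_mongodb import MongoDBAtlasVectorSearch

vectorstore = MongoDBAtlasVectorSearch(
    collection=collection,  # assumed: an existing pymongo Collection
    embedding=embedding,  # assumed: any Embeddings implementation
    index_name="my-index",  # placeholder
)

# Provide explicit IDs: a 24-character hex string is stored as an ObjectId;
# any other string is used as-is in the _id index.
ids = vectorstore.add_texts(
    texts=["Dogs are tough.", "Cats have fluff."],
    ids=["664f1f77bcf86cd799439011", "custom-id-2"],
)

# The returned IDs are plain strings and can be passed straight to delete().
vectorstore.delete(ids=ids)
```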

This PR closes the following JIRA issues.

* [PYTHON-4446](https://jira.mongodb.org/browse/PYTHON-4446) MongoDBVectorSearch delete / add_texts function rework
* [PYTHON-4435](https://jira.mongodb.org/browse/PYTHON-4435) Add support for "Indexing"
* [PYTHON-4534](https://jira.mongodb.org/browse/PYTHON-4534) Ensure datetimes are json-serializable

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Casey Clements authored on 2024-07-17 13:26:20 -07:00, committed by GitHub.
parent cc2cbfabfc
commit a47f69a120
4 changed files with 403 additions and 89 deletions


@@ -8,10 +8,15 @@ are duplicated in this utility respectively from modules:
- "libs/community/langchain_community/utils/math.py"
"""
from __future__ import annotations
import logging
from typing import List, Union
from datetime import date, datetime
from typing import Any, Dict, List, Union
import numpy as np
from bson import ObjectId
from bson.errors import InvalidId
logger = logging.getLogger(__name__)
@@ -88,3 +93,55 @@ def maximal_marginal_relevance(
idxs.append(idx_to_add)
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
return idxs
def str_to_oid(str_repr: str) -> ObjectId | str:
"""Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.
To be consistent with ObjectId, input must be a 24 character hex string.
If it is not, MongoDB will happily use the string in the main _id index.
Importantly, the str representation that comes out of MongoDB will have this form.
Args:
str_repr: id as string.
Returns:
ObjectId
"""
try:
return ObjectId(str_repr)
except InvalidId:
logger.debug(
"ObjectIds must be 12-character byte or 24-character hex strings. "
"Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
)
return str_repr
def oid_to_str(oid: ObjectId) -> str:
"""Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.
Instructive helper to show where data is coming out of MongoDB.
Args:
oid: bson.ObjectId
Returns:
24 character hex string.
"""
return str(oid)
def make_serializable(
obj: Dict[str, Any],
) -> None:
"""Recursively cast values in a dict to a form able to json.dump"""
for k, v in obj.items():
if isinstance(v, dict):
make_serializable(v)
elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
obj[k] = [oid_to_str(item) for item in v]
elif isinstance(v, ObjectId):
obj[k] = oid_to_str(v)
elif isinstance(v, (datetime, date)):
obj[k] = v.isoformat()
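
The helpers above are easiest to see by example. A minimal sketch (not part
of the diff; behavior as described in the docstrings):

```python
from datetime import datetime

from bson import ObjectId

from langchain_mongodb.utils import make_serializable, oid_to_str, str_to_oid

# A 24-character hex string round-trips through ObjectId...
oid = str_to_oid("6f6e6568656c6c6f68656768")
assert isinstance(oid, ObjectId)
assert oid_to_str(oid) == "6f6e6568656c6c6f68656768"

# ...while any other string is logged at debug level and passed through.
assert str_to_oid("my-custom-id") == "my-custom-id"

# make_serializable mutates the dict in place so json.dumps can handle it.
doc = {"_id": ObjectId(), "created": datetime(2024, 7, 17), "nested": {"oid": ObjectId()}}
make_serializable(doc)
assert isinstance(doc["_id"], str)
assert doc["created"] == "2024-07-17T00:00:00"
```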


@@ -16,7 +16,6 @@ from typing import (
)
import numpy as np
from bson import ObjectId, json_util
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
@@ -30,7 +29,12 @@ from langchain_mongodb.index import (
create_vector_search_index,
update_vector_search_index,
)
from langchain_mongodb.utils import maximal_marginal_relevance
from langchain_mongodb.utils import (
make_serializable,
maximal_marginal_relevance,
oid_to_str,
str_to_oid,
)
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
VST = TypeVar("VST", bound=VectorStore)
@@ -143,51 +147,153 @@ class MongoDBAtlasVectorSearch(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
"""Add texts, create embeddings, and add to the Collection and index.
Important notes on ids:
- If _id or id is a key in the metadatas dicts, one must
pop them and provide them as a separate list.
- They must be unique.
- If they are not provided, the VectorStore will create unique ones,
stored as bson.ObjectIds internally, and as strings in LangChain.
These will appear in Document.metadata with the key '_id'.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique ids that will be used as index in VectorStore.
See note on ids.
Returns:
List of ids from adding the texts into the vectorstore.
List of ids added to the vectorstore.
"""
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
# Check to see if metadata includes ids
if metadatas is not None and (
metadatas[0].get("_id") or metadatas[0].get("id")
):
logger.warning(
"_id or id key found in metadata. "
"Please pop from each dict and input as separate list."
"Retrieving methods will include the same id as '_id' in metadata."
)
texts_batch = texts
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
metadatas_batch = _metadatas
result_ids = []
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
if batch_size:
texts_batch = []
metadatas_batch = []
size = 0
for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
i = 0
for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
size += len(text) + len(metadata)
texts_batch.append(text)
metadatas_batch.append(metadata)
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
if (j + 1) % batch_size == 0 or size >= 47_000_000:
if ids:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch, ids[i : j + 1]
)
else:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch
)
result_ids.extend(batch_res)
texts_batch = []
metadatas_batch = []
size = 0
i = j + 1
if texts_batch:
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) # type: ignore
return [str(id) for id in result_ids]
if ids:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch, ids[i : j + 1]
)
else:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch
)
result_ids.extend(batch_res)
return result_ids
def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
def bulk_embed_and_insert_texts(
self,
texts: Union[List[str], Iterable[str]],
metadatas: Union[List[dict], Generator[dict, Any, Any]],
ids: Optional[List[str]] = None,
) -> List[str]:
"""Bulk insert single batch of texts, embeddings, and optionally ids.
See add_texts for additional details.
"""
if not texts:
return []
# Embed and create the documents
embeddings = self._embedding.embed_documents(texts)
to_insert = [
{self._text_key: t, self._embedding_key: embedding, **m}
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# Compute embedding vectors
embeddings = self._embedding.embed_documents(texts) # type: ignore
if ids:
to_insert = [
{
"_id": str_to_oid(i),
self._text_key: t,
self._embedding_key: embedding,
**m,
}
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
]
else:
to_insert = [
{self._text_key: t, self._embedding_key: embedding, **m}
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
insert_result = self._collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids
return [oid_to_str(_id) for _id in insert_result.inserted_ids]
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
**kwargs: Any,
) -> List[str]:
"""Add documents to the vectorstore.
Args:
documents: Documents to add to the vectorstore.
ids: Optional list of unique ids that will be used as index in VectorStore.
See note on ids in add_texts.
batch_size: Number of documents to insert at a time.
Tuning this may help with performance and sidestep MongoDB limits.
Returns:
List of IDs of the added texts.
"""
n_docs = len(documents)
if ids:
assert len(ids) == n_docs, "Number of ids must equal number of documents."
result_ids = []
start = 0
for end in range(batch_size, n_docs + batch_size, batch_size):
texts, metadatas = zip(
*[(doc.page_content, doc.metadata) for doc in documents[start:end]]
)
if ids:
result_ids.extend(
self.bulk_embed_and_insert_texts(
texts=texts, metadatas=metadatas, ids=ids[start:end]
)
)
else:
result_ids.extend(
self.bulk_embed_and_insert_texts(texts=texts, metadatas=metadatas)
)
start = end
return result_ids
def _similarity_search_with_score(
self,
@@ -196,8 +302,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
include_embedding: bool = False,
include_ids: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Core implementation."""
params = {
"queryVector": embedding,
"path": self._embedding_key,
@@ -223,22 +331,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
def _make_serializable(obj: Dict[str, Any]) -> None:
for k, v in obj.items():
if isinstance(v, dict):
_make_serializable(v)
elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
obj[k] = [json_util.default(item) for item in v]
elif isinstance(v, ObjectId):
obj[k] = json_util.default(v)
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
# Make every ObjectId found JSON-Serializable
# following format used in bson.json_util.loads
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
_make_serializable(res)
make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score))
return docs
@@ -363,6 +459,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
embedding: Embeddings,
metadatas: Optional[List[Dict]] = None,
collection: Optional[Collection[MongoDBDocumentType]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
@@ -394,25 +491,25 @@ class MongoDBAtlasVectorSearch(VectorStore):
if collection is None:
raise ValueError("Must provide 'collection' named parameter.")
vectorstore = cls(collection, embedding, **kwargs)
vectorstore.add_texts(texts, metadatas=metadatas)
vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
return vectorstore
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by ObjectId or other criteria.
"""Delete documents from VectorStore by ids.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
**kwargs: Other keyword arguments passed to Collection.delete_many()
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
search_params: dict[str, Any] = {}
filter = {}
if ids:
search_params["_id"] = {"$in": [ObjectId(id) for id in ids]}
return self._collection.delete_many({**search_params, **kwargs}).acknowledged
oids = [str_to_oid(i) for i in ids]
filter = {"_id": {"$in": oids}}
return self._collection.delete_many(filter=filter, **kwargs).acknowledged
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
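
Taken together with `from_texts` and `delete` above, a hedged end-to-end
sketch (same assumed `collection` and `embedding` as in the earlier example;
the index name is a placeholder):

```python
from langchain_mongodb import MongoDBAtlasVectorSearch

texts = ["What is a sandwich?", "That fence is purple."]
ids = ["custom-id-1", "custom-id-2"]  # need not be 24-character hex strings

vectorstore = MongoDBAtlasVectorSearch.from_texts(
    texts,
    embedding,  # assumed: any Embeddings implementation
    collection=collection,  # assumed: an existing pymongo Collection
    ids=ids,
    index_name="my-index",  # placeholder
)

# Retrieval methods surface the id in Document.metadata["_id"]...
doc = vectorstore.similarity_search("sandwich", k=1)[0]
assert doc.metadata["_id"] in ids

# ...and delete() accepts the same string ids.
assert vectorstore.delete(ids=ids)
```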


@@ -4,9 +4,10 @@ from __future__ import annotations
import os
from time import monotonic, sleep
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
import pytest # type: ignore[import-not-found]
from bson import ObjectId
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
@@ -15,7 +16,9 @@ from pymongo.errors import OperationFailure
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.index import drop_vector_search_index
from tests.utils import ConsistentFakeEmbeddings
from langchain_mongodb.utils import oid_to_str
from ..utils import ConsistentFakeEmbeddings
INDEX_NAME = "langchain-test-index-vectorstores"
INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
@@ -25,20 +28,25 @@ DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
INDEX_DB_NAME = "langchain_test_index_db"
DIMENSIONS = 1536
TIMEOUT = 10.0
TIMEOUT = 120.0
INTERVAL = 0.5
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
def bulk_embed_and_insert_texts(
self,
texts: Union[List[str], Iterable[str]],
metadatas: Union[List[dict], Generator[dict, Any, Any]],
ids: Optional[List[str]] = None,
) -> List:
"""Patched insert_texts that waits for data to be indexed before returning"""
ids = super()._insert_texts(texts, metadatas)
ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
start = monotonic()
while len(ids) != self.similarity_search("sandwich") and (
while len(ids_inserted) != len(self.similarity_search("sandwich")) and (
monotonic() - start <= TIMEOUT
):
sleep(INTERVAL)
return ids
return ids_inserted
def create_vector_search_index(
self,
@@ -87,6 +95,16 @@ def collection() -> Collection:
return get_collection()
@pytest.fixture
def texts() -> List[str]:
return [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"That fence is purple.",
]
@pytest.fixture()
def index_collection() -> Collection:
return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)
@@ -118,12 +136,10 @@ class TestMongoDBAtlasVectorSearch:
)
@pytest.fixture
def embedding_openai(self) -> Embeddings:
def embeddings(self) -> Embeddings:
return ConsistentFakeEmbeddings(DIMENSIONS)
def test_from_documents(
self, embedding_openai: Embeddings, collection: Any
) -> None:
def test_from_documents(self, embeddings: Embeddings, collection: Any) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
@@ -133,7 +149,7 @@ class TestMongoDBAtlasVectorSearch:
]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
embeddings,
collection=collection,
index_name=INDEX_NAME,
)
@@ -143,7 +159,7 @@ class TestMongoDBAtlasVectorSearch:
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_documents_no_embedding_return(
self, embedding_openai: Embeddings, collection: Any
self, embeddings: Embeddings, collection: Any
) -> None:
"""Test end to end construction and search."""
documents = [
@@ -154,7 +170,7 @@ class TestMongoDBAtlasVectorSearch:
]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
embeddings,
collection=collection,
index_name=INDEX_NAME,
)
@@ -166,7 +182,7 @@ class TestMongoDBAtlasVectorSearch:
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_documents_embedding_return(
self, embedding_openai: Embeddings, collection: Any
self, embeddings: Embeddings, collection: Any
) -> None:
"""Test end to end construction and search."""
documents = [
@@ -177,7 +193,7 @@ class TestMongoDBAtlasVectorSearch:
]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
embeddings,
collection=collection,
index_name=INDEX_NAME,
)
@@ -188,16 +204,12 @@ class TestMongoDBAtlasVectorSearch:
# Check for the presence of the metadata key
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"That fence is purple.",
]
def test_from_texts(
self, embeddings: Embeddings, collection: Collection, texts: List[str]
) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
embeddings,
collection=collection,
index_name=INDEX_NAME,
)
@@ -205,19 +217,16 @@ class TestMongoDBAtlasVectorSearch:
assert len(output) == 1
def test_from_texts_with_metadatas(
self, embedding_openai: Embeddings, collection: Any
self,
embeddings: Embeddings,
collection: Collection,
texts: List[str],
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
metakeys = ["a", "b", "c", "d", "e"]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
embeddings,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
@@ -228,18 +237,12 @@ class TestMongoDBAtlasVectorSearch:
assert any([key in output[0].metadata for key in metakeys])
def test_from_texts_with_metadatas_and_pre_filter(
self, embedding_openai: Embeddings, collection: Any
self, embeddings: Embeddings, collection: Any, texts: List[str]
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
embeddings,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
@@ -249,11 +252,11 @@ class TestMongoDBAtlasVectorSearch:
)
assert output == []
def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
def test_mmr(self, embeddings: Embeddings, collection: Any) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
embeddings,
collection=collection,
index_name=INDEX_NAME,
)
@@ -263,19 +266,175 @@ class TestMongoDBAtlasVectorSearch:
assert output[0].page_content == "foo"
assert output[1].page_content != "foo"
def test_delete(
self, embeddings: Embeddings, collection: Any, texts: List[str]
) -> None:
vectorstore = MongoDBAtlasVectorSearch(
collection=collection,
embedding=embeddings,
index_name=INDEX_NAME,
)
clxn: Collection = vectorstore._collection
assert clxn.count_documents({}) == 0
ids = vectorstore.add_texts(texts)
assert clxn.count_documents({}) == len(texts)
deleted = vectorstore.delete(ids[-2:])
assert deleted
assert clxn.count_documents({}) == len(texts) - 2
new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"])
assert set(new_ids).intersection(set(ids)) == set() # new ids will be unique.
assert isinstance(new_ids, list)
assert all(isinstance(i, str) for i in new_ids)
assert len(new_ids) == 2
assert clxn.count_documents({}) == 4
def test_add_texts(
self,
embeddings: Embeddings,
collection: Collection,
texts: List[str],
) -> None:
"""Tests API of add_texts, focussing on id treatment"""
metadatas: List[Dict[str, Any]] = [
{"a": 1},
{"b": 1},
{"c": 1},
{"d": 1, "e": 2},
]
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection, embedding=embeddings, index_name=INDEX_NAME
)
# Case 1. Add texts without ids
provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
all_docs = list(vectorstore._collection.find({}))
assert all("_id" in doc for doc in all_docs)
docids = set(doc["_id"] for doc in all_docs)
assert all(isinstance(_id, ObjectId) for _id in docids)
assert set(provided_ids) == set(oid_to_str(oid) for oid in docids)
# Case 2: Test Document.metadata looks right. i.e. contains _id
search_res = vectorstore.similarity_search_with_score("sandwich", k=1)
doc, score = search_res[0]
assert "_id" in doc.metadata
# Case 3: Add new ids that are 24-char hex strings
hex_ids = [oid_to_str(ObjectId()) for _ in range(2)]
hex_texts = ["Text for hex_id"] * len(hex_ids)
out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids)
assert set(out_ids) == set(hex_ids)
assert collection.count_documents({}) == len(texts) + len(hex_texts)
assert all(
isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({})
)
# Case 4: Add new ids that cannot be cast to ObjectId
# - We can still index and search on them
str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."]
str_texts = str_ids # No reason for them to differ
out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids)
assert set(out_ids) == set(str_ids)
assert collection.count_documents({}) == 8
res = vectorstore.similarity_search("sandwich", k=8)
assert any(str_ids[0] in doc.metadata["_id"] for doc in res)
# Case 5: Test adding in multiple batches
batch_size = 2
batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)]
batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)]
out_ids = vectorstore.add_texts(
texts=batch_texts, ids=batch_ids, batch_size=batch_size
)
assert set(out_ids) == set(batch_ids)
assert collection.count_documents({}) == 12
# Case 6: _ids in metadata
collection.delete_many({})
# 6a. Unique _id in metadata, but ids=None
# Will be added as if ids kwarg provided
i = 0
n = len(texts)
assert len(metadatas) == n
_ids = [str(i) for i in range(n)]
for md in metadatas:
md["_id"] = _ids[i]
i += 1
returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
assert returned_ids == ["0", "1", "2", "3"]
assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids)
# 6b. Unique "id", not "_id", but ids=None
# New ids will be assigned
i = 1
for md in metadatas:
md.pop("_id")
md["id"] = f"{1}"
i += 1
returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
assert len(set(returned_ids).intersection(set(_ids))) == 0
def test_add_documents(
self,
embeddings: Embeddings,
collection: Collection,
index_name: str = INDEX_NAME,
) -> None:
"""Tests add_documents."""
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection, embedding=embeddings, index_name=INDEX_NAME
)
# Case 1: No ids
n_docs = 10
batch_size = 3
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
result_ids = vectorstore.add_documents(docs, batch_size=batch_size)
assert len(result_ids) == n_docs
assert collection.count_documents({}) == n_docs
# Case 2: ids
collection.delete_many({})
n_docs = 10
batch_size = 3
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
ids = [str(i) for i in range(n_docs)]
result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
assert len(result_ids) == n_docs
assert set(ids) == set(collection.distinct("_id"))
# Case 3: Single batch
collection.delete_many({})
n_docs = 3
batch_size = 10
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
ids = [str(i) for i in range(n_docs)]
result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
assert len(result_ids) == n_docs
assert set(ids) == set(collection.distinct("_id"))
def test_index_creation(
self, embedding_openai: Embeddings, index_collection: Any
self, embeddings: Embeddings, index_collection: Any
) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch(
index_collection, embedding_openai, index_name=INDEX_CREATION_NAME
index_collection, embeddings, index_name=INDEX_CREATION_NAME
)
vectorstore.create_vector_search_index(dimensions=1536)
def test_index_update(
self, embedding_openai: Embeddings, index_collection: Any
) -> None:
def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch(
index_collection, embedding_openai, index_name=INDEX_CREATION_NAME
index_collection, embeddings, index_name=INDEX_CREATION_NAME
)
vectorstore.create_vector_search_index(dimensions=1536)
vectorstore.create_vector_search_index(dimensions=1536, update=True)


@@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.utils import str_to_oid
from tests.utils import ConsistentFakeEmbeddings, MockCollection
INDEX_NAME = "langchain-test-index"
@@ -81,7 +82,7 @@ class TestMongoDBAtlasVectorSearch:
assert loads(dumps(output[0].page_content)) == output[0].page_content
assert loads(dumps(output[0].metadata)) == output[0].metadata
json_metadata = dumps(output[0].metadata) # normal json.dumps
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
assert isinstance(str_to_oid(json_util.loads(json_metadata)["_id"]), ObjectId)
def test_from_documents(
self, embedding_openai: Embeddings, collection: MockCollection