Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 13:54:48 +00:00)
partners/mongodb: Significant MongoDBVectorSearch ID enhancements (#23535)
## Description

This pull request improves the treatment of document IDs in `MongoDBAtlasVectorSearch`. The method signatures of `add_documents`, `add_texts`, `delete`, and `from_texts` now include an `ids: Optional[List[str]]` keyword argument, giving the user greater control. As before, IDs may also be inferred from `Document.metadata['_id']` when present, but this is no longer required. IDs can now also optionally be returned from searches.

This PR closes the following JIRA issues:

* [PYTHON-4446](https://jira.mongodb.org/browse/PYTHON-4446) MongoDBVectorSearch delete / add_texts function rework
* [PYTHON-4435](https://jira.mongodb.org/browse/PYTHON-4435) Add support for "Indexing"
* [PYTHON-4534](https://jira.mongodb.org/browse/PYTHON-4534) Ensure datetimes are JSON-serializable

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent cc2cbfabfc
commit a47f69a120
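Before the diff, a minimal sketch of the new surface. It assumes an existing `collection` (a pymongo `Collection` backed by an Atlas vector search index) and an `embedding` model; the index name is illustrative:

```python
from langchain_mongodb import MongoDBAtlasVectorSearch

# `collection` and `embedding` are assumed to exist already.
vectorstore = MongoDBAtlasVectorSearch(
    collection=collection,
    embedding=embedding,
    index_name="langchain-test-index",
)

# ids may now be supplied explicitly; returned ids are always strings.
ids = vectorstore.add_texts(
    texts=["Dogs are tough.", "Cats have fluff."],
    ids=["my-first-doc", "my-second-doc"],
)

# The same strings can be passed back to delete().
vectorstore.delete(ids)
```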
**libs/partners/mongodb/langchain_mongodb/utils.py**

```diff
@@ -8,10 +8,15 @@ are duplicated in this utility respectively from modules:
 - "libs/community/langchain_community/utils/math.py"
 """
 
+from __future__ import annotations
+
 import logging
-from typing import List, Union
+from datetime import date, datetime
+from typing import Any, Dict, List, Union
 
 import numpy as np
+from bson import ObjectId
+from bson.errors import InvalidId
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -88,3 +93,55 @@ def maximal_marginal_relevance(
         idxs.append(idx_to_add)
         selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
     return idxs
+
+
+def str_to_oid(str_repr: str) -> ObjectId | str:
+    """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.
+
+    To be consistent with ObjectId, input must be a 24 character hex string.
+    If it is not, MongoDB will happily use the string in the main _id index.
+    Importantly, the str representation that comes out of MongoDB will have this form.
+
+    Args:
+        str_repr: id as string.
+
+    Returns:
+        ObjectId
+    """
+    try:
+        return ObjectId(str_repr)
+    except InvalidId:
+        logger.debug(
+            "ObjectIds must be 12-character byte or 24-character hex strings. "
+            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
+        )
+        return str_repr
+
+
+def oid_to_str(oid: ObjectId) -> str:
+    """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.
+
+    Instructive helper to show where data is coming out of MongoDB.
+
+    Args:
+        oid: bson.ObjectId
+
+    Returns:
+        24 character hex string.
+    """
+    return str(oid)
+
+
+def make_serializable(
+    obj: Dict[str, Any],
+) -> None:
+    """Recursively cast values in a dict to a form able to json.dump"""
+    for k, v in obj.items():
+        if isinstance(v, dict):
+            make_serializable(v)
+        elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
+            obj[k] = [oid_to_str(item) for item in v]
+        elif isinstance(v, ObjectId):
+            obj[k] = oid_to_str(v)
+        elif isinstance(v, (datetime, date)):
+            obj[k] = v.isoformat()
```
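Together these helpers define the id and serialization contract used below. A quick sketch of the intended round trips, assuming `langchain_mongodb` is installed:

```python
from datetime import datetime

from bson import ObjectId
from langchain_mongodb.utils import make_serializable, oid_to_str, str_to_oid

oid = ObjectId()
assert str_to_oid(oid_to_str(oid)) == oid  # 24-char hex strings round-trip
assert str_to_oid("my-custom-key") == "my-custom-key"  # anything else stays a str

doc = {"_id": oid, "nested": {"created": datetime(2024, 6, 1)}}
make_serializable(doc)  # mutates the dict in place, recursing into nested dicts
assert doc == {"_id": oid_to_str(oid), "nested": {"created": "2024-06-01T00:00:00"}}
```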
**libs/partners/mongodb/langchain_mongodb/vectorstores.py**

```diff
@@ -16,7 +16,6 @@ from typing import (
 )
 
 import numpy as np
-from bson import ObjectId, json_util
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables.config import run_in_executor
@@ -30,7 +29,12 @@ from langchain_mongodb.index import (
     create_vector_search_index,
     update_vector_search_index,
 )
-from langchain_mongodb.utils import maximal_marginal_relevance
+from langchain_mongodb.utils import (
+    make_serializable,
+    maximal_marginal_relevance,
+    oid_to_str,
+    str_to_oid,
+)
 
 MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
 VST = TypeVar("VST", bound=VectorStore)
```
```diff
@@ -143,51 +147,153 @@ class MongoDBAtlasVectorSearch(VectorStore):
         self,
         texts: Iterable[str],
         metadatas: Optional[List[Dict[str, Any]]] = None,
+        ids: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
+        """Add texts, create embeddings, and add to the Collection and index.
+
+        Important notes on ids:
+            - If _id or id is a key in the metadatas dicts, one must
+                pop them and provide as separate list.
+            - They must be unique.
+            - If they are not provided, the VectorStore will create unique ones,
+                stored as bson.ObjectIds internally, and strings in Langchain.
+                These will appear in Document.metadata with key, '_id'.
 
         Args:
             texts: Iterable of strings to add to the vectorstore.
             metadatas: Optional list of metadatas associated with the texts.
+            ids: Optional list of unique ids that will be used as index in VectorStore.
+                See note on ids.
 
         Returns:
-            List of ids from adding the texts into the vectorstore.
+            List of ids added to the vectorstore.
         """
-        batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
-        _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
+
+        # Check to see if metadata includes ids
+        if metadatas is not None and (
+            metadatas[0].get("_id") or metadatas[0].get("id")
+        ):
+            logger.warning(
+                "_id or id key found in metadata. "
+                "Please pop from each dict and input as separate list."
+                "Retrieving methods will include the same id as '_id' in metadata."
+            )
 
         texts_batch = texts
+        _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
         metadatas_batch = _metadatas
 
         result_ids = []
+        batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
         if batch_size:
             texts_batch = []
             metadatas_batch = []
             size = 0
-            for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
+            i = 0
+            for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
                 size += len(text) + len(metadata)
                 texts_batch.append(text)
                 metadatas_batch.append(metadata)
-                if (i + 1) % batch_size == 0 or size >= 47_000_000:
-                    result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
+                if (j + 1) % batch_size == 0 or size >= 47_000_000:
+                    if ids:
+                        batch_res = self.bulk_embed_and_insert_texts(
+                            texts_batch, metadatas_batch, ids[i : j + 1]
+                        )
+                    else:
+                        batch_res = self.bulk_embed_and_insert_texts(
+                            texts_batch, metadatas_batch
+                        )
+                    result_ids.extend(batch_res)
                     texts_batch = []
                     metadatas_batch = []
                     size = 0
+                    i = j + 1
         if texts_batch:
-            result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))  # type: ignore
-        return [str(id) for id in result_ids]
+            if ids:
+                batch_res = self.bulk_embed_and_insert_texts(
+                    texts_batch, metadatas_batch, ids[i : j + 1]
+                )
+            else:
+                batch_res = self.bulk_embed_and_insert_texts(
+                    texts_batch, metadatas_batch
+                )
+            result_ids.extend(batch_res)
+        return result_ids
 
-    def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
+    def bulk_embed_and_insert_texts(
+        self,
+        texts: Union[List[str], Iterable[str]],
+        metadatas: Union[List[dict], Generator[dict, Any, Any]],
+        ids: Optional[List[str]] = None,
+    ) -> List[str]:
+        """Bulk insert single batch of texts, embeddings, and optionally ids.
+
+        See add_texts for additional details.
+        """
         if not texts:
             return []
-        # Embed and create the documents
-        embeddings = self._embedding.embed_documents(texts)
-        to_insert = [
-            {self._text_key: t, self._embedding_key: embedding, **m}
-            for t, m, embedding in zip(texts, metadatas, embeddings)
-        ]
+        # Compute embedding vectors
+        embeddings = self._embedding.embed_documents(texts)  # type: ignore
+        if ids:
+            to_insert = [
+                {
+                    "_id": str_to_oid(i),
+                    self._text_key: t,
+                    self._embedding_key: embedding,
+                    **m,
+                }
+                for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
+            ]
+        else:
+            to_insert = [
+                {self._text_key: t, self._embedding_key: embedding, **m}
+                for t, m, embedding in zip(texts, metadatas, embeddings)
+            ]
         # insert the documents in MongoDB Atlas
         insert_result = self._collection.insert_many(to_insert)  # type: ignore
-        return insert_result.inserted_ids
+        return [oid_to_str(_id) for _id in insert_result.inserted_ids]
 
+    def add_documents(
+        self,
+        documents: List[Document],
+        ids: Optional[List[str]] = None,
+        batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Add documents to the vectorstore.
+
+        Args:
+            documents: Documents to add to the vectorstore.
+            ids: Optional list of unique ids that will be used as index in VectorStore.
+                See note on ids in add_texts.
+            batch_size: Number of documents to insert at a time.
+                Tuning this may help with performance and sidestep MongoDB limits.
+
+        Returns:
+            List of IDs of the added texts.
+        """
+        n_docs = len(documents)
+        if ids:
+            assert len(ids) == n_docs, "Number of ids must equal number of documents."
+        result_ids = []
+        start = 0
+        for end in range(batch_size, n_docs + batch_size, batch_size):
+            texts, metadatas = zip(
+                *[(doc.page_content, doc.metadata) for doc in documents[start:end]]
+            )
+            if ids:
+                result_ids.extend(
+                    self.bulk_embed_and_insert_texts(
+                        texts=texts, metadatas=metadatas, ids=ids[start:end]
+                    )
+                )
+            else:
+                result_ids.extend(
+                    self.bulk_embed_and_insert_texts(texts=texts, metadatas=metadatas)
+                )
+            start = end
+        return result_ids
 
     def _similarity_search_with_score(
         self,
```
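The batching loop slices `ids` with the same `[i : j + 1]` window used for `texts`, flushing whenever the batch count or the roughly 47 MB size guard (presumably tracking MongoDB's 48 MB message limit) is hit. A standalone sketch of just the windowing logic, with hypothetical toy data:

```python
# Standalone illustration: ids are cut with the same window as the texts
# that go into each bulk insert, so every batch stays aligned.
texts = [f"text {k}" for k in range(5)]
ids = [f"id-{k}" for k in range(5)]
batch_size = 2

batches, i = [], 0
for j, text in enumerate(texts):
    if (j + 1) % batch_size == 0:
        batches.append((texts[i : j + 1], ids[i : j + 1]))
        i = j + 1
if i < len(texts):  # final partial batch
    batches.append((texts[i:], ids[i:]))

assert batches[0] == (["text 0", "text 1"], ["id-0", "id-1"])
assert batches[-1] == (["text 4"], ["id-4"])
```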
```diff
@@ -196,8 +302,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
         pre_filter: Optional[Dict] = None,
         post_filter_pipeline: Optional[List[Dict]] = None,
         include_embedding: bool = False,
+        include_ids: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
+        """Core implementation."""
         params = {
             "queryVector": embedding,
             "path": self._embedding_key,
```
```diff
@@ -223,22 +331,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
         cursor = self._collection.aggregate(pipeline)  # type: ignore[arg-type]
         docs = []
 
-        def _make_serializable(obj: Dict[str, Any]) -> None:
-            for k, v in obj.items():
-                if isinstance(v, dict):
-                    _make_serializable(v)
-                elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
-                    obj[k] = [json_util.default(item) for item in v]
-                elif isinstance(v, ObjectId):
-                    obj[k] = json_util.default(v)
-
         for res in cursor:
             text = res.pop(self._text_key)
             score = res.pop("score")
-            # Make every ObjectId found JSON-Serializable
-            # following format used in bson.json_util.loads
-            # e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')}  # noqa: E501
-            _make_serializable(res)
+            make_serializable(res)
             docs.append((Document(page_content=text, metadata=res), score))
         return docs
```
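Swapping the local `_make_serializable` for the shared helper also changes the wire format of `_id` in returned metadata: the old code emitted bson's extended-JSON form, while the new helper emits a plain hex string that `json.dumps` handles directly. For instance:

```python
from bson import ObjectId, json_util

oid = ObjectId("664f1f4ee8b2c2f8a1a3e6d4")

# Old behavior: bson extended JSON, e.g. {"_id": {"$oid": "664f..."}}
assert json_util.default(oid) == {"$oid": "664f1f4ee8b2c2f8a1a3e6d4"}

# New behavior: a plain 24-character hex string
assert str(oid) == "664f1f4ee8b2c2f8a1a3e6d4"
```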
```diff
@@ -363,6 +459,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
         embedding: Embeddings,
         metadatas: Optional[List[Dict]] = None,
         collection: Optional[Collection[MongoDBDocumentType]] = None,
+        ids: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> MongoDBAtlasVectorSearch:
         """Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
```
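With the kwarg plumbed through, ids can also be supplied at construction time. A sketch, assuming the same `collection` and `embedding` as above:

```python
vectorstore = MongoDBAtlasVectorSearch.from_texts(
    texts=["Dogs are tough.", "Cats have fluff."],
    embedding=embedding,
    ids=["doc-1", "doc-2"],  # forwarded to add_texts
    collection=collection,
    index_name="langchain-test-index",
)
```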
```diff
@@ -394,25 +491,25 @@ class MongoDBAtlasVectorSearch(VectorStore):
         if collection is None:
             raise ValueError("Must provide 'collection' named parameter.")
         vectorstore = cls(collection, embedding, **kwargs)
-        vectorstore.add_texts(texts, metadatas=metadatas)
+        vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
         return vectorstore
 
     def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
-        """Delete by ObjectId or other criteria.
+        """Delete documents from VectorStore by ids.
 
         Args:
             ids: List of ids to delete.
-            **kwargs: Other keyword arguments that subclasses might use.
+            **kwargs: Other keyword arguments passed to Collection.delete_many()
 
         Returns:
             Optional[bool]: True if deletion is successful,
             False otherwise, None if not implemented.
         """
-        search_params: dict[str, Any] = {}
+        filter = {}
         if ids:
-            search_params["_id"] = {"$in": [ObjectId(id) for id in ids]}
-
-        return self._collection.delete_many({**search_params, **kwargs}).acknowledged
+            oids = [str_to_oid(i) for i in ids]
+            filter = {"_id": {"$in": oids}}
+        return self._collection.delete_many(filter=filter, **kwargs).acknowledged
 
     async def adelete(
         self, ids: Optional[List[str]] = None, **kwargs: Any
```
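Because `str_to_oid` falls back to the raw string, the new `delete` matches both auto-generated hex ids and arbitrary custom ids. A usage sketch, assuming a `vectorstore` configured as above (the hex value is the example from `str_to_oid`'s debug message):

```python
# Hex ids (as returned for auto-generated ObjectIds) and arbitrary string ids
# can be mixed; str_to_oid casts the former and passes the latter through.
ids = vectorstore.add_texts(
    texts=["to be removed", "also removed"],
    ids=["6f6e6568656c6c6f68656768", "my-custom-key"],
)
assert vectorstore.delete(ids)  # True when delete_many is acknowledged
```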
**libs/partners/mongodb/tests/integration_tests/test_vectorstores.py**

```diff
@@ -4,9 +4,10 @@ from __future__ import annotations
 
 import os
 from time import monotonic, sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, Iterable, List, Optional, Union
 
 import pytest  # type: ignore[import-not-found]
+from bson import ObjectId
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from pymongo import MongoClient
@@ -15,7 +16,9 @@ from pymongo.errors import OperationFailure
 
 from langchain_mongodb import MongoDBAtlasVectorSearch
 from langchain_mongodb.index import drop_vector_search_index
-from tests.utils import ConsistentFakeEmbeddings
+from langchain_mongodb.utils import oid_to_str
+
+from ..utils import ConsistentFakeEmbeddings
 
 INDEX_NAME = "langchain-test-index-vectorstores"
 INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
```
```diff
@@ -25,20 +28,25 @@ DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
 INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
 INDEX_DB_NAME = "langchain_test_index_db"
 DIMENSIONS = 1536
-TIMEOUT = 10.0
+TIMEOUT = 120.0
 INTERVAL = 0.5
 
 
 class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
-    def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> List:
+    def bulk_embed_and_insert_texts(
+        self,
+        texts: Union[List[str], Iterable[str]],
+        metadatas: Union[List[dict], Generator[dict, Any, Any]],
+        ids: Optional[List[str]] = None,
+    ) -> List:
         """Patched insert_texts that waits for data to be indexed before returning"""
-        ids = super()._insert_texts(texts, metadatas)
+        ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
         start = monotonic()
-        while len(ids) != self.similarity_search("sandwich") and (
+        while len(ids_inserted) != len(self.similarity_search("sandwich")) and (
             monotonic() - start <= TIMEOUT
         ):
             sleep(INTERVAL)
-        return ids
+        return ids_inserted
 
     def create_vector_search_index(
         self,
```
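The patched test class polls until the Atlas search index reflects the writes, which is also why `TIMEOUT` grows from 10 s to 120 s. The loop is an instance of a generic wait-until pattern; a standalone sketch:

```python
from time import monotonic, sleep
from typing import Callable


def wait_until(
    predicate: Callable[[], bool], timeout: float = 120.0, interval: float = 0.5
) -> bool:
    """Poll `predicate` until it returns True or `timeout` seconds elapse."""
    start = monotonic()
    while not predicate() and monotonic() - start <= timeout:
        sleep(interval)
    return predicate()
```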
```diff
@@ -87,6 +95,16 @@ def collection() -> Collection:
     return get_collection()
 
 
+@pytest.fixture
+def texts() -> List[str]:
+    return [
+        "Dogs are tough.",
+        "Cats have fluff.",
+        "What is a sandwich?",
+        "That fence is purple.",
+    ]
+
+
 @pytest.fixture()
 def index_collection() -> Collection:
     return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)
```
```diff
@@ -118,12 +136,10 @@ class TestMongoDBAtlasVectorSearch:
         )
 
     @pytest.fixture
-    def embedding_openai(self) -> Embeddings:
+    def embeddings(self) -> Embeddings:
         return ConsistentFakeEmbeddings(DIMENSIONS)
 
-    def test_from_documents(
-        self, embedding_openai: Embeddings, collection: Any
-    ) -> None:
+    def test_from_documents(self, embeddings: Embeddings, collection: Any) -> None:
         """Test end to end construction and search."""
         documents = [
             Document(page_content="Dogs are tough.", metadata={"a": 1}),
```
```diff
@@ -133,7 +149,7 @@ class TestMongoDBAtlasVectorSearch:
         ]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
             documents,
-            embedding_openai,
+            embeddings,
             collection=collection,
             index_name=INDEX_NAME,
         )
```
```diff
@@ -143,7 +159,7 @@ class TestMongoDBAtlasVectorSearch:
         assert any([key.page_content == output[0].page_content for key in documents])
 
     def test_from_documents_no_embedding_return(
-        self, embedding_openai: Embeddings, collection: Any
+        self, embeddings: Embeddings, collection: Any
     ) -> None:
         """Test end to end construction and search."""
         documents = [
```
```diff
@@ -154,7 +170,7 @@ class TestMongoDBAtlasVectorSearch:
         ]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
             documents,
-            embedding_openai,
+            embeddings,
             collection=collection,
             index_name=INDEX_NAME,
         )
```
```diff
@@ -166,7 +182,7 @@ class TestMongoDBAtlasVectorSearch:
         assert any([key.page_content == output[0].page_content for key in documents])
 
     def test_from_documents_embedding_return(
-        self, embedding_openai: Embeddings, collection: Any
+        self, embeddings: Embeddings, collection: Any
     ) -> None:
         """Test end to end construction and search."""
         documents = [
```
```diff
@@ -177,7 +193,7 @@ class TestMongoDBAtlasVectorSearch:
         ]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
             documents,
-            embedding_openai,
+            embeddings,
             collection=collection,
             index_name=INDEX_NAME,
         )
```
```diff
@@ -188,16 +204,12 @@ class TestMongoDBAtlasVectorSearch:
         # Check for the presence of the metadata key
         assert any([key.page_content == output[0].page_content for key in documents])
 
-    def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
-        texts = [
-            "Dogs are tough.",
-            "Cats have fluff.",
-            "What is a sandwich?",
-            "That fence is purple.",
-        ]
+    def test_from_texts(
+        self, embeddings: Embeddings, collection: Collection, texts: List[str]
+    ) -> None:
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
             texts,
-            embedding_openai,
+            embeddings,
             collection=collection,
             index_name=INDEX_NAME,
         )
```
```diff
@@ -205,19 +217,16 @@ class TestMongoDBAtlasVectorSearch:
         assert len(output) == 1
 
     def test_from_texts_with_metadatas(
-        self, embedding_openai: Embeddings, collection: Any
+        self,
+        embeddings: Embeddings,
+        collection: Collection,
+        texts: List[str],
     ) -> None:
-        texts = [
-            "Dogs are tough.",
-            "Cats have fluff.",
-            "What is a sandwich?",
-            "The fence is purple.",
-        ]
         metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
         metakeys = ["a", "b", "c", "d", "e"]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
             texts,
-            embedding_openai,
+            embeddings,
             metadatas=metadatas,
             collection=collection,
             index_name=INDEX_NAME,
```
```diff
@@ -228,18 +237,12 @@ class TestMongoDBAtlasVectorSearch:
         assert any([key in output[0].metadata for key in metakeys])
 
     def test_from_texts_with_metadatas_and_pre_filter(
-        self, embedding_openai: Embeddings, collection: Any
+        self, embeddings: Embeddings, collection: Any, texts: List[str]
     ) -> None:
-        texts = [
-            "Dogs are tough.",
-            "Cats have fluff.",
-            "What is a sandwich?",
-            "The fence is purple.",
-        ]
         metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
             texts,
-            embedding_openai,
+            embeddings,
             metadatas=metadatas,
             collection=collection,
             index_name=INDEX_NAME,
```
```diff
@@ -249,11 +252,11 @@ class TestMongoDBAtlasVectorSearch:
         )
         assert output == []
 
-    def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
+    def test_mmr(self, embeddings: Embeddings, collection: Any) -> None:
         texts = ["foo", "foo", "fou", "foy"]
         vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
             texts,
-            embedding_openai,
+            embeddings,
             collection=collection,
             index_name=INDEX_NAME,
         )
```
```diff
@@ -263,19 +266,175 @@ class TestMongoDBAtlasVectorSearch:
         assert output[0].page_content == "foo"
         assert output[1].page_content != "foo"
 
+    def test_delete(
+        self, embeddings: Embeddings, collection: Any, texts: List[str]
+    ) -> None:
+        vectorstore = MongoDBAtlasVectorSearch(  # PatchedMongoDBAtlasVectorSearch(
+            collection=collection,
+            embedding=embeddings,
+            index_name=INDEX_NAME,
+        )
+        clxn: Collection = vectorstore._collection
+        assert clxn.count_documents({}) == 0
+        ids = vectorstore.add_texts(texts)
+        assert clxn.count_documents({}) == len(texts)
+
+        deleted = vectorstore.delete(ids[-2:])
+        assert deleted
+        assert clxn.count_documents({}) == len(texts) - 2
+
+        new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"])
+        assert set(new_ids).intersection(set(ids)) == set()  # new ids will be unique.
+        assert isinstance(new_ids, list)
+        assert all(isinstance(i, str) for i in new_ids)
+        assert len(new_ids) == 2
+        assert clxn.count_documents({}) == 4
+
+    def test_add_texts(
+        self,
+        embeddings: Embeddings,
+        collection: Collection,
+        texts: List[str],
+    ) -> None:
+        """Tests API of add_texts, focussing on id treatment"""
+        metadatas: List[Dict[str, Any]] = [
+            {"a": 1},
+            {"b": 1},
+            {"c": 1},
+            {"d": 1, "e": 2},
+        ]
+
+        vectorstore = PatchedMongoDBAtlasVectorSearch(
+            collection=collection, embedding=embeddings, index_name=INDEX_NAME
+        )
+
+        # Case 1. Add texts without ids
+        provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
+        all_docs = list(vectorstore._collection.find({}))
+        assert all("_id" in doc for doc in all_docs)
+        docids = set(doc["_id"] for doc in all_docs)
+        assert all(isinstance(_id, ObjectId) for _id in docids)  #
+        assert set(provided_ids) == set(oid_to_str(oid) for oid in docids)
+
+        # Case 2: Test Document.metadata looks right. i.e. contains _id
+        search_res = vectorstore.similarity_search_with_score("sandwich", k=1)
+        doc, score = search_res[0]
+        assert "_id" in doc.metadata
+
+        # Case 3: Add new ids that are 24-char hex strings
+        hex_ids = [oid_to_str(ObjectId()) for _ in range(2)]
+        hex_texts = ["Text for hex_id"] * len(hex_ids)
+        out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids)
+        assert set(out_ids) == set(hex_ids)
+        assert collection.count_documents({}) == len(texts) + len(hex_texts)
+        assert all(
+            isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({})
+        )
+
+        # Case 4: Add new ids that cannot be cast to ObjectId
+        #  - We can still index and search on them
+        str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."]
+        str_texts = str_ids  # No reason for them to differ
+        out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids)
+        assert set(out_ids) == set(str_ids)
+        assert collection.count_documents({}) == 8
+        res = vectorstore.similarity_search("sandwich", k=8)
+        assert any(str_ids[0] in doc.metadata["_id"] for doc in res)
+
+        # Case 5: Test adding in multiple batches
+        batch_size = 2
+        batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)]
+        batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)]
+        out_ids = vectorstore.add_texts(
+            texts=batch_texts, ids=batch_ids, batch_size=batch_size
+        )
+        assert set(out_ids) == set(batch_ids)
+        assert collection.count_documents({}) == 12
+
+        # Case 6: _ids in metadata
+        collection.delete_many({})
+        # 6a. Unique _id in metadata, but ids=None
+        # Will be added as if ids kwarg provided
+        i = 0
+        n = len(texts)
+        assert len(metadatas) == n
+        _ids = [str(i) for i in range(n)]
+        for md in metadatas:
+            md["_id"] = _ids[i]
+            i += 1
+        returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
+        assert returned_ids == ["0", "1", "2", "3"]
+        assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids)
+
+        # 6b. Unique "id", not "_id", but ids=None
+        # New ids will be assigned
+        i = 1
+        for md in metadatas:
+            md.pop("_id")
+            md["id"] = f"{1}"
+            i += 1
+        returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
+        assert len(set(returned_ids).intersection(set(_ids))) == 0
+
+    def test_add_documents(
+        self,
+        embeddings: Embeddings,
+        collection: Collection,
+        index_name: str = INDEX_NAME,
+    ) -> None:
+        """Tests add_documents."""
+        vectorstore = PatchedMongoDBAtlasVectorSearch(
+            collection=collection, embedding=embeddings, index_name=INDEX_NAME
+        )
+
+        # Case 1: No ids
+        n_docs = 10
+        batch_size = 3
+        docs = [
+            Document(page_content=f"document {i}", metadata={"i": i})
+            for i in range(n_docs)
+        ]
+        result_ids = vectorstore.add_documents(docs, batch_size=batch_size)
+        assert len(result_ids) == n_docs
+        assert collection.count_documents({}) == n_docs
+
+        # Case 2: ids
+        collection.delete_many({})
+        n_docs = 10
+        batch_size = 3
+        docs = [
+            Document(page_content=f"document {i}", metadata={"i": i})
+            for i in range(n_docs)
+        ]
+        ids = [str(i) for i in range(n_docs)]
+        result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
+        assert len(result_ids) == n_docs
+        assert set(ids) == set(collection.distinct("_id"))
+
+        # Case 3: Single batch
+        collection.delete_many({})
+        n_docs = 3
+        batch_size = 10
+        docs = [
+            Document(page_content=f"document {i}", metadata={"i": i})
+            for i in range(n_docs)
+        ]
+        ids = [str(i) for i in range(n_docs)]
+        result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
+        assert len(result_ids) == n_docs
+        assert set(ids) == set(collection.distinct("_id"))
+
     def test_index_creation(
-        self, embedding_openai: Embeddings, index_collection: Any
+        self, embeddings: Embeddings, index_collection: Any
     ) -> None:
         vectorstore = PatchedMongoDBAtlasVectorSearch(
-            index_collection, embedding_openai, index_name=INDEX_CREATION_NAME
+            index_collection, embeddings, index_name=INDEX_CREATION_NAME
         )
         vectorstore.create_vector_search_index(dimensions=1536)
 
-    def test_index_update(
-        self, embedding_openai: Embeddings, index_collection: Any
-    ) -> None:
+    def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None:
         vectorstore = PatchedMongoDBAtlasVectorSearch(
-            index_collection, embedding_openai, index_name=INDEX_CREATION_NAME
+            index_collection, embeddings, index_name=INDEX_CREATION_NAME
         )
         vectorstore.create_vector_search_index(dimensions=1536)
         vectorstore.create_vector_search_index(dimensions=1536, update=True)
```
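Case 6 pins down the metadata-inference rules: a `_id` key in the metadata dicts is honored as the stored id (after the warning), while a plain `id` key is not special and fresh ids are generated. In miniature, assuming an empty `collection` and a `vectorstore` wired up as in these tests:

```python
# 6a in miniature: "_id" in metadata is used as the stored document id.
returned = vectorstore.add_texts(
    texts=["x", "y"], metadatas=[{"_id": "0"}, {"_id": "1"}]
)
assert returned == ["0", "1"]

# 6b in miniature: "id" (no underscore) is just another metadata field;
# new ObjectId-backed ids are generated instead.
returned = vectorstore.add_texts(
    texts=["x", "y"], metadatas=[{"id": "0"}, {"id": "1"}]
)
assert set(returned).isdisjoint({"0", "1"})
```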
**libs/partners/mongodb/tests/unit_tests/test_vectorstores.py**

```diff
@@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
 from pymongo.collection import Collection
 
 from langchain_mongodb import MongoDBAtlasVectorSearch
+from langchain_mongodb.utils import str_to_oid
 from tests.utils import ConsistentFakeEmbeddings, MockCollection
 
 INDEX_NAME = "langchain-test-index"
```
```diff
@@ -81,7 +82,7 @@ class TestMongoDBAtlasVectorSearch:
         assert loads(dumps(output[0].page_content)) == output[0].page_content
-        assert loads(dumps(output[0].metadata)) == output[0].metadata
         json_metadata = dumps(output[0].metadata)  # normal json.dumps
-        assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
+        assert isinstance(str_to_oid(json_util.loads(json_metadata)["_id"]), ObjectId)
 
     def test_from_documents(
         self, embedding_openai: Embeddings, collection: MockCollection
```