mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
partners/mongodb : Significant MongoDBVectorSearch ID enhancements (#23535)
## Description This pull request improves the treatment of document IDs in `MongoDBAtlasVectorSearch`. The method signatures of add_documents, add_texts, delete, and from_texts now include an `ids: Optional[List[str]]` keyword argument, giving the user greater control. Note that, as before, IDs may also be inferred from `Document.metadata['_id']` if present, but this is no longer required. IDs can also optionally be returned from searches. This PR closes the following JIRA issues: * [PYTHON-4446](https://jira.mongodb.org/browse/PYTHON-4446) MongoDBVectorSearch delete / add_texts function rework * [PYTHON-4435](https://jira.mongodb.org/browse/PYTHON-4435) Add support for "Indexing" * [PYTHON-4534](https://jira.mongodb.org/browse/PYTHON-4534) Ensure datetimes are json-serializable --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
cc2cbfabfc
commit
a47f69a120
@ -8,10 +8,15 @@ are duplicated in this utility respectively from modules:
|
|||||||
- "libs/community/langchain_community/utils/math.py"
|
- "libs/community/langchain_community/utils/math.py"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Union
|
from datetime import date, datetime
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from bson import ObjectId
|
||||||
|
from bson.errors import InvalidId
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -88,3 +93,55 @@ def maximal_marginal_relevance(
|
|||||||
idxs.append(idx_to_add)
|
idxs.append(idx_to_add)
|
||||||
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
|
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
|
||||||
return idxs
|
return idxs
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_oid(str_repr: str) -> ObjectId | str:
    """Try to turn a string id into a BSON ObjectId, falling back to the string.

    The conversion is attempted with ``ObjectId(str_repr)``. When that raises
    ``InvalidId`` a debug message is logged and the input is handed back
    unchanged, so ids that are not ObjectId-shaped can still be used as-is.

    Args:
        str_repr: id as string.

    Returns:
        The ObjectId built from ``str_repr`` when the conversion succeeds,
        otherwise ``str_repr`` itself.
    """
    try:
        converted = ObjectId(str_repr)
    except InvalidId:
        # Not ObjectId-compatible: keep the caller's string as the id.
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
        return str_repr
    return converted
|
||||||
|
|
||||||
|
|
||||||
|
def oid_to_str(oid: ObjectId) -> str:
    """Return the ``str()`` form of a BSON ObjectId.

    Kept as a named helper so that call sites make it obvious where values
    leaving MongoDB are converted to plain strings.

    Args:
        oid: bson.ObjectId

    Returns:
        The result of ``str(oid)``.
    """
    text_form = str(oid)
    return text_form
|
||||||
|
|
||||||
|
|
||||||
|
def make_serializable(
    obj: Dict[str, Any],
) -> None:
    """Recursively cast values in a dict, in place, to a form able to json.dump.

    Every value is converted individually:

    - dicts are processed recursively (in place);
    - lists are replaced by a new list whose elements are converted one by
      one, so mixed lists and dicts nested inside lists are handled too;
    - ``datetime`` / ``date`` values become their ``isoformat()`` string,
      both as scalars and as list elements;
    - ``ObjectId`` values become strings via ``oid_to_str``;
    - anything else is left untouched.

    Args:
        obj: Dictionary to convert. It is modified in place.

    Returns:
        None.
    """
    # Only existing keys are reassigned, so the dict's size never changes
    # while it is being iterated.
    for key, value in obj.items():
        obj[key] = _to_serializable(value)


def _to_serializable(value: Any) -> Any:
    """Return the json-friendly form of a single value (see make_serializable)."""
    if isinstance(value, dict):
        make_serializable(value)
        return value
    if isinstance(value, list):
        return [_to_serializable(item) for item in value]
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if isinstance(value, ObjectId):
        return oid_to_str(value)
    return value
|
||||||
|
@ -16,7 +16,6 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from bson import ObjectId, json_util
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
@ -30,7 +29,12 @@ from langchain_mongodb.index import (
|
|||||||
create_vector_search_index,
|
create_vector_search_index,
|
||||||
update_vector_search_index,
|
update_vector_search_index,
|
||||||
)
|
)
|
||||||
from langchain_mongodb.utils import maximal_marginal_relevance
|
from langchain_mongodb.utils import (
|
||||||
|
make_serializable,
|
||||||
|
maximal_marginal_relevance,
|
||||||
|
oid_to_str,
|
||||||
|
str_to_oid,
|
||||||
|
)
|
||||||
|
|
||||||
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
MongoDBDocumentType = TypeVar("MongoDBDocumentType", bound=Dict[str, Any])
|
||||||
VST = TypeVar("VST", bound=VectorStore)
|
VST = TypeVar("VST", bound=VectorStore)
|
||||||
@ -143,51 +147,153 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Run more texts through the embeddings and add to the vectorstore.
|
"""Add texts, create embeddings, and add to the Collection and index.
|
||||||
|
|
||||||
|
Important notes on ids:
|
||||||
|
- If _id or id is a key in the metadatas dicts, one must
|
||||||
|
pop them and provide as separate list.
|
||||||
|
- They must be unique.
|
||||||
|
- If they are not provided, the VectorStore will create unique ones,
|
||||||
|
stored as bson.ObjectIds internally, and strings in Langchain.
|
||||||
|
These will appear in Document.metadata with key, '_id'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: Iterable of strings to add to the vectorstore.
|
texts: Iterable of strings to add to the vectorstore.
|
||||||
metadatas: Optional list of metadatas associated with the texts.
|
metadatas: Optional list of metadatas associated with the texts.
|
||||||
|
ids: Optional list of unique ids that will be used as index in VectorStore.
|
||||||
|
See note on ids.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of ids from adding the texts into the vectorstore.
|
List of ids added to the vectorstore.
|
||||||
"""
|
"""
|
||||||
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
|
|
||||||
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
|
# Check to see if metadata includes ids
|
||||||
|
if metadatas is not None and (
|
||||||
|
metadatas[0].get("_id") or metadatas[0].get("id")
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"_id or id key found in metadata. "
|
||||||
|
"Please pop from each dict and input as separate list."
|
||||||
|
"Retrieving methods will include the same id as '_id' in metadata."
|
||||||
|
)
|
||||||
|
|
||||||
texts_batch = texts
|
texts_batch = texts
|
||||||
|
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
|
||||||
metadatas_batch = _metadatas
|
metadatas_batch = _metadatas
|
||||||
|
|
||||||
result_ids = []
|
result_ids = []
|
||||||
|
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
|
||||||
if batch_size:
|
if batch_size:
|
||||||
texts_batch = []
|
texts_batch = []
|
||||||
metadatas_batch = []
|
metadatas_batch = []
|
||||||
size = 0
|
size = 0
|
||||||
for i, (text, metadata) in enumerate(zip(texts, _metadatas)):
|
i = 0
|
||||||
|
for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
|
||||||
size += len(text) + len(metadata)
|
size += len(text) + len(metadata)
|
||||||
texts_batch.append(text)
|
texts_batch.append(text)
|
||||||
metadatas_batch.append(metadata)
|
metadatas_batch.append(metadata)
|
||||||
if (i + 1) % batch_size == 0 or size >= 47_000_000:
|
if (j + 1) % batch_size == 0 or size >= 47_000_000:
|
||||||
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch))
|
if ids:
|
||||||
|
batch_res = self.bulk_embed_and_insert_texts(
|
||||||
|
texts_batch, metadatas_batch, ids[i : j + 1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_res = self.bulk_embed_and_insert_texts(
|
||||||
|
texts_batch, metadatas_batch
|
||||||
|
)
|
||||||
|
result_ids.extend(batch_res)
|
||||||
texts_batch = []
|
texts_batch = []
|
||||||
metadatas_batch = []
|
metadatas_batch = []
|
||||||
size = 0
|
size = 0
|
||||||
|
i = j + 1
|
||||||
if texts_batch:
|
if texts_batch:
|
||||||
result_ids.extend(self._insert_texts(texts_batch, metadatas_batch)) # type: ignore
|
if ids:
|
||||||
return [str(id) for id in result_ids]
|
batch_res = self.bulk_embed_and_insert_texts(
|
||||||
|
texts_batch, metadatas_batch, ids[i : j + 1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_res = self.bulk_embed_and_insert_texts(
|
||||||
|
texts_batch, metadatas_batch
|
||||||
|
)
|
||||||
|
result_ids.extend(batch_res)
|
||||||
|
return result_ids
|
||||||
|
|
||||||
def bulk_embed_and_insert_texts(
    self,
    texts: Union[List[str], Iterable[str]],
    metadatas: Union[List[dict], Generator[dict, Any, Any]],
    ids: Optional[List[str]] = None,
) -> List[str]:
    """Embed one batch of texts and insert it with a single insert_many call.

    Each stored document holds the text under ``self._text_key``, its
    embedding under ``self._embedding_key`` and the entries of its metadata
    dict. When ``ids`` is given, each id is passed through ``str_to_oid`` and
    written as ``_id``; metadata entries are merged afterwards.

    Args:
        texts: Texts of this batch. A falsy value short-circuits to ``[]``.
        metadatas: One metadata dict per text.
        ids: Optional ids, paired positionally with ``texts``.

    Returns:
        The ``inserted_ids`` reported by the collection, each converted with
        ``oid_to_str``.
    """
    if not texts:
        return []
    # Compute embedding vectors
    vectors = self._embedding.embed_documents(texts)  # type: ignore
    records = []
    if ids:
        for doc_id, text, meta, vector in zip(ids, texts, metadatas, vectors):
            records.append(
                {
                    "_id": str_to_oid(doc_id),
                    self._text_key: text,
                    self._embedding_key: vector,
                    **meta,
                }
            )
    else:
        for text, meta, vector in zip(texts, metadatas, vectors):
            records.append(
                {self._text_key: text, self._embedding_key: vector, **meta}
            )
    # insert the documents in MongoDB Atlas
    outcome = self._collection.insert_many(records)  # type: ignore
    return [oid_to_str(inserted) for inserted in outcome.inserted_ids]
|
||||||
|
|
||||||
|
def add_documents(
    self,
    documents: List[Document],
    ids: Optional[List[str]] = None,
    batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
    **kwargs: Any,
) -> List[str]:
    """Embed and insert documents in consecutive batches.

    Args:
        documents: Documents whose ``page_content`` and ``metadata`` are
            handed to ``bulk_embed_and_insert_texts``.
        ids: Optional list of ids, one per document. When given, the slice
            matching each batch is forwarded with that batch.
        batch_size: Maximum number of documents per
            ``bulk_embed_and_insert_texts`` call.
        **kwargs: Accepted but not used by this method.

    Returns:
        The ids returned by every batch insert, concatenated in batch order.
    """
    total = len(documents)
    if ids:
        assert len(ids) == total, "Number of ids must equal number of documents."
    collected = []
    for lo in range(0, total, batch_size):
        hi = lo + batch_size
        chunk = documents[lo:hi]
        call_args = {
            "texts": tuple(doc.page_content for doc in chunk),
            "metadatas": tuple(doc.metadata for doc in chunk),
        }
        if ids:
            # Keep ids aligned with the documents of this batch.
            call_args["ids"] = ids[lo:hi]
        collected.extend(self.bulk_embed_and_insert_texts(**call_args))
    return collected
|
||||||
|
|
||||||
def _similarity_search_with_score(
|
def _similarity_search_with_score(
|
||||||
self,
|
self,
|
||||||
@ -196,8 +302,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
pre_filter: Optional[Dict] = None,
|
pre_filter: Optional[Dict] = None,
|
||||||
post_filter_pipeline: Optional[List[Dict]] = None,
|
post_filter_pipeline: Optional[List[Dict]] = None,
|
||||||
include_embedding: bool = False,
|
include_embedding: bool = False,
|
||||||
|
include_ids: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Core implementation."""
|
||||||
params = {
|
params = {
|
||||||
"queryVector": embedding,
|
"queryVector": embedding,
|
||||||
"path": self._embedding_key,
|
"path": self._embedding_key,
|
||||||
@ -223,22 +331,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
def _make_serializable(obj: Dict[str, Any]) -> None:
|
|
||||||
for k, v in obj.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
_make_serializable(v)
|
|
||||||
elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
|
|
||||||
obj[k] = [json_util.default(item) for item in v]
|
|
||||||
elif isinstance(v, ObjectId):
|
|
||||||
obj[k] = json_util.default(v)
|
|
||||||
|
|
||||||
for res in cursor:
|
for res in cursor:
|
||||||
text = res.pop(self._text_key)
|
text = res.pop(self._text_key)
|
||||||
score = res.pop("score")
|
score = res.pop("score")
|
||||||
# Make every ObjectId found JSON-Serializable
|
make_serializable(res)
|
||||||
# following format used in bson.json_util.loads
|
|
||||||
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
|
|
||||||
_make_serializable(res)
|
|
||||||
docs.append((Document(page_content=text, metadata=res), score))
|
docs.append((Document(page_content=text, metadata=res), score))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -363,6 +459,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
metadatas: Optional[List[Dict]] = None,
|
metadatas: Optional[List[Dict]] = None,
|
||||||
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
collection: Optional[Collection[MongoDBDocumentType]] = None,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MongoDBAtlasVectorSearch:
|
) -> MongoDBAtlasVectorSearch:
|
||||||
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
|
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
|
||||||
@ -394,25 +491,25 @@ class MongoDBAtlasVectorSearch(VectorStore):
|
|||||||
if collection is None:
|
if collection is None:
|
||||||
raise ValueError("Must provide 'collection' named parameter.")
|
raise ValueError("Must provide 'collection' named parameter.")
|
||||||
vectorstore = cls(collection, embedding, **kwargs)
|
vectorstore = cls(collection, embedding, **kwargs)
|
||||||
vectorstore.add_texts(texts, metadatas=metadatas)
|
vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
|
||||||
return vectorstore
|
return vectorstore
|
||||||
|
|
||||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
    """Delete documents from the collection, selected by id.

    Args:
        ids: List of ids to delete. Each one is passed through
            ``str_to_oid`` and matched against ``_id`` with ``$in``.
        **kwargs: Other keyword arguments passed to Collection.delete_many()

    Returns:
        The ``acknowledged`` attribute of the delete_many result.
    """
    # NOTE(review): when ids is None or empty, an empty filter is sent to
    # delete_many; in PyMongo that matches every document -- confirm intended.
    if ids:
        query = {"_id": {"$in": [str_to_oid(raw_id) for raw_id in ids]}}
    else:
        query = {}
    outcome = self._collection.delete_many(filter=query, **kwargs)
    return outcome.acknowledged
|
||||||
|
|
||||||
async def adelete(
|
async def adelete(
|
||||||
self, ids: Optional[List[str]] = None, **kwargs: Any
|
self, ids: Optional[List[str]] = None, **kwargs: Any
|
||||||
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from time import monotonic, sleep
|
from time import monotonic, sleep
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import pytest # type: ignore[import-not-found]
|
import pytest # type: ignore[import-not-found]
|
||||||
|
from bson import ObjectId
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
@ -15,7 +16,9 @@ from pymongo.errors import OperationFailure
|
|||||||
|
|
||||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||||
from langchain_mongodb.index import drop_vector_search_index
|
from langchain_mongodb.index import drop_vector_search_index
|
||||||
from tests.utils import ConsistentFakeEmbeddings
|
from langchain_mongodb.utils import oid_to_str
|
||||||
|
|
||||||
|
from ..utils import ConsistentFakeEmbeddings
|
||||||
|
|
||||||
INDEX_NAME = "langchain-test-index-vectorstores"
|
INDEX_NAME = "langchain-test-index-vectorstores"
|
||||||
INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
|
INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
|
||||||
@ -25,20 +28,25 @@ DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
|
|||||||
INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
|
INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
|
||||||
INDEX_DB_NAME = "langchain_test_index_db"
|
INDEX_DB_NAME = "langchain_test_index_db"
|
||||||
DIMENSIONS = 1536
|
DIMENSIONS = 1536
|
||||||
TIMEOUT = 10.0
|
TIMEOUT = 120.0
|
||||||
INTERVAL = 0.5
|
INTERVAL = 0.5
|
||||||
|
|
||||||
|
|
||||||
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
|
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
|
||||||
def bulk_embed_and_insert_texts(
    self,
    texts: Union[List[str], Iterable[str]],
    metadatas: Union[List[dict], Generator[dict, Any, Any]],
    ids: Optional[List[str]] = None,
) -> List:
    """Patched insert_texts that waits for data to be indexed before returning

    After delegating to the parent implementation, polls
    ``similarity_search("sandwich")`` every INTERVAL seconds until it returns
    as many results as ids were inserted, or until TIMEOUT seconds have passed.
    """
    ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
    began = monotonic()
    while True:
        if len(ids_inserted) == len(self.similarity_search("sandwich")):
            break
        if monotonic() - began > TIMEOUT:
            # Give up waiting; the caller still receives the inserted ids.
            break
        sleep(INTERVAL)
    return ids_inserted
|
||||||
|
|
||||||
def create_vector_search_index(
|
def create_vector_search_index(
|
||||||
self,
|
self,
|
||||||
@ -87,6 +95,16 @@ def collection() -> Collection:
|
|||||||
return get_collection()
|
return get_collection()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def texts() -> List[str]:
    """Provide the four sample sentences used by the text-based tests."""
    sample_sentences = [
        "Dogs are tough.",
        "Cats have fluff.",
        "What is a sandwich?",
        "That fence is purple.",
    ]
    return sample_sentences
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def index_collection() -> Collection:
|
def index_collection() -> Collection:
|
||||||
return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)
|
return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)
|
||||||
@ -118,12 +136,10 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def embedding_openai(self) -> Embeddings:
|
def embeddings(self) -> Embeddings:
|
||||||
return ConsistentFakeEmbeddings(DIMENSIONS)
|
return ConsistentFakeEmbeddings(DIMENSIONS)
|
||||||
|
|
||||||
def test_from_documents(
|
def test_from_documents(self, embeddings: Embeddings, collection: Any) -> None:
|
||||||
self, embedding_openai: Embeddings, collection: Any
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
documents = [
|
documents = [
|
||||||
Document(page_content="Dogs are tough.", metadata={"a": 1}),
|
Document(page_content="Dogs are tough.", metadata={"a": 1}),
|
||||||
@ -133,7 +149,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
]
|
]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
||||||
documents,
|
documents,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
@ -143,7 +159,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert any([key.page_content == output[0].page_content for key in documents])
|
assert any([key.page_content == output[0].page_content for key in documents])
|
||||||
|
|
||||||
def test_from_documents_no_embedding_return(
|
def test_from_documents_no_embedding_return(
|
||||||
self, embedding_openai: Embeddings, collection: Any
|
self, embeddings: Embeddings, collection: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
documents = [
|
documents = [
|
||||||
@ -154,7 +170,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
]
|
]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
||||||
documents,
|
documents,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
@ -166,7 +182,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert any([key.page_content == output[0].page_content for key in documents])
|
assert any([key.page_content == output[0].page_content for key in documents])
|
||||||
|
|
||||||
def test_from_documents_embedding_return(
|
def test_from_documents_embedding_return(
|
||||||
self, embedding_openai: Embeddings, collection: Any
|
self, embeddings: Embeddings, collection: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
documents = [
|
documents = [
|
||||||
@ -177,7 +193,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
]
|
]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
|
||||||
documents,
|
documents,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
@ -188,16 +204,12 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
# Check for the presence of the metadata key
|
# Check for the presence of the metadata key
|
||||||
assert any([key.page_content == output[0].page_content for key in documents])
|
assert any([key.page_content == output[0].page_content for key in documents])
|
||||||
|
|
||||||
def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
|
def test_from_texts(
|
||||||
texts = [
|
self, embeddings: Embeddings, collection: Collection, texts: List[str]
|
||||||
"Dogs are tough.",
|
) -> None:
|
||||||
"Cats have fluff.",
|
|
||||||
"What is a sandwich?",
|
|
||||||
"That fence is purple.",
|
|
||||||
]
|
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
@ -205,19 +217,16 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
|
|
||||||
def test_from_texts_with_metadatas(
|
def test_from_texts_with_metadatas(
|
||||||
self, embedding_openai: Embeddings, collection: Any
|
self,
|
||||||
|
embeddings: Embeddings,
|
||||||
|
collection: Collection,
|
||||||
|
texts: List[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
texts = [
|
|
||||||
"Dogs are tough.",
|
|
||||||
"Cats have fluff.",
|
|
||||||
"What is a sandwich?",
|
|
||||||
"The fence is purple.",
|
|
||||||
]
|
|
||||||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
||||||
metakeys = ["a", "b", "c", "d", "e"]
|
metakeys = ["a", "b", "c", "d", "e"]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
@ -228,18 +237,12 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert any([key in output[0].metadata for key in metakeys])
|
assert any([key in output[0].metadata for key in metakeys])
|
||||||
|
|
||||||
def test_from_texts_with_metadatas_and_pre_filter(
|
def test_from_texts_with_metadatas_and_pre_filter(
|
||||||
self, embedding_openai: Embeddings, collection: Any
|
self, embeddings: Embeddings, collection: Any, texts: List[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
texts = [
|
|
||||||
"Dogs are tough.",
|
|
||||||
"Cats have fluff.",
|
|
||||||
"What is a sandwich?",
|
|
||||||
"The fence is purple.",
|
|
||||||
]
|
|
||||||
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
@ -249,11 +252,11 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
)
|
)
|
||||||
assert output == []
|
assert output == []
|
||||||
|
|
||||||
def test_mmr(self, embedding_openai: Embeddings, collection: Any) -> None:
|
def test_mmr(self, embeddings: Embeddings, collection: Any) -> None:
|
||||||
texts = ["foo", "foo", "fou", "foy"]
|
texts = ["foo", "foo", "fou", "foy"]
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding_openai,
|
embeddings,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
index_name=INDEX_NAME,
|
index_name=INDEX_NAME,
|
||||||
)
|
)
|
||||||
@ -263,19 +266,175 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert output[0].page_content == "foo"
|
assert output[0].page_content == "foo"
|
||||||
assert output[1].page_content != "foo"
|
assert output[1].page_content != "foo"
|
||||||
|
|
||||||
|
def test_delete(
    self, embeddings: Embeddings, collection: Any, texts: List[str]
) -> None:
    """Check that delete removes the requested ids and later adds get new ids."""
    # Uses the unpatched store (not PatchedMongoDBAtlasVectorSearch).
    store = MongoDBAtlasVectorSearch(
        collection=collection,
        embedding=embeddings,
        index_name=INDEX_NAME,
    )
    coll: Collection = store._collection
    assert coll.count_documents({}) == 0
    original_ids = store.add_texts(texts)
    assert coll.count_documents({}) == len(texts)

    # Remove the last two inserted documents.
    was_deleted = store.delete(original_ids[-2:])
    assert was_deleted
    assert coll.count_documents({}) == len(texts) - 2

    fresh_ids = store.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"])
    # new ids will be unique.
    assert set(fresh_ids).intersection(set(original_ids)) == set()
    assert isinstance(fresh_ids, list)
    assert all(isinstance(one_id, str) for one_id in fresh_ids)
    assert len(fresh_ids) == 2
    assert coll.count_documents({}) == 4
|
||||||
|
|
||||||
|
def test_add_texts(
    self,
    embeddings: Embeddings,
    collection: Collection,
    texts: List[str],
) -> None:
    """Tests API of add_texts, focussing on id treatment.

    Covers, against a single collection: generated ids, ``_id`` appearing in
    search-result metadata, user-supplied 24-char hex ids, user-supplied
    arbitrary string ids, multi-batch inserts, and ids carried in metadata
    under ``_id`` versus ``id``.
    """
    metadatas: List[Dict[str, Any]] = [
        {"a": 1},
        {"b": 1},
        {"c": 1},
        {"d": 1, "e": 2},
    ]

    vectorstore = PatchedMongoDBAtlasVectorSearch(
        collection=collection, embedding=embeddings, index_name=INDEX_NAME
    )

    # Case 1. Add texts without ids
    provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
    all_docs = list(vectorstore._collection.find({}))
    assert all("_id" in doc for doc in all_docs)
    docids = {doc["_id"] for doc in all_docs}
    assert all(isinstance(_id, ObjectId) for _id in docids)
    assert set(provided_ids) == {oid_to_str(oid) for oid in docids}

    # Case 2: Test Document.metadata looks right. i.e. contains _id
    search_res = vectorstore.similarity_search_with_score("sandwich", k=1)
    doc, score = search_res[0]
    assert "_id" in doc.metadata

    # Case 3: Add new ids that are 24-char hex strings
    hex_ids = [oid_to_str(ObjectId()) for _ in range(2)]
    hex_texts = ["Text for hex_id"] * len(hex_ids)
    out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids)
    assert set(out_ids) == set(hex_ids)
    assert collection.count_documents({}) == len(texts) + len(hex_texts)
    assert all(
        isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({})
    )

    # Case 4: Add new ids that cannot be cast to ObjectId
    #   - We can still index and search on them
    str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."]
    str_texts = str_ids  # No reason for them to differ
    out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids)
    assert set(out_ids) == set(str_ids)
    assert collection.count_documents({}) == 8
    res = vectorstore.similarity_search("sandwich", k=8)
    assert any(str_ids[0] in doc.metadata["_id"] for doc in res)

    # Case 5: Test adding in multiple batches
    batch_size = 2
    batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)]
    batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)]
    out_ids = vectorstore.add_texts(
        texts=batch_texts, ids=batch_ids, batch_size=batch_size
    )
    assert set(out_ids) == set(batch_ids)
    assert collection.count_documents({}) == 12

    # Case 6: _ids in metadata
    collection.delete_many({})
    # 6a. Unique _id in metadata, but ids=None
    # Will be added as if ids kwarg provided
    n = len(texts)
    assert len(metadatas) == n
    _ids = [str(i) for i in range(n)]
    for md, _id in zip(metadatas, _ids):
        md["_id"] = _id
    returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
    assert returned_ids == ["0", "1", "2", "3"]
    assert {d["_id"] for d in vectorstore._collection.find({})} == set(_ids)

    # 6b. Unique "id", not "_id", but ids=None
    # New ids will be assigned
    for i, md in enumerate(metadatas, start=1):
        md.pop("_id")
        # Use the running counter so each "id" really is unique; the literal
        # f"{1}" previously gave every metadata dict the same value.
        md["id"] = f"{i}"
    returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
    assert len(set(returned_ids).intersection(set(_ids))) == 0
|
||||||
|
|
||||||
|
def test_add_documents(
    self,
    embeddings: Embeddings,
    collection: Collection,
    index_name: str = INDEX_NAME,
) -> None:
    """Tests add_documents with and without explicit ids.

    Three scenarios are exercised against the same vector store:

    1. No ids supplied, documents inserted over several batches.
    2. Explicit string ids supplied, several batches.
    3. Explicit string ids supplied, everything fits in one batch.

    Args:
        embeddings: Embedding model handed to the vector store.
        collection: Collection the vector store writes to.
            NOTE(review): case 1 counts every document in the collection,
            so it assumes the fixture hands over an empty collection.
        index_name: Name of the search index passed to the vector store.
            Defaults to ``INDEX_NAME``.
    """

    def make_docs(count: int) -> List[Document]:
        # One document per index; the index is mirrored into the metadata.
        return [
            Document(page_content=f"document {i}", metadata={"i": i})
            for i in range(count)
        ]

    # Use the parameter rather than the module constant, so that the
    # argument is actually honoured when a caller overrides it.
    vectorstore = PatchedMongoDBAtlasVectorSearch(
        collection=collection, embedding=embeddings, index_name=index_name
    )

    # Case 1: No ids
    n_docs = 10
    batch_size = 3
    result_ids = vectorstore.add_documents(make_docs(n_docs), batch_size=batch_size)
    assert len(result_ids) == n_docs
    assert collection.count_documents({}) == n_docs

    # Case 2: ids, spread over multiple batches (10 docs, batch size 3)
    # Case 3: Single batch (3 docs, batch size 10)
    for n_docs, batch_size in ((10, 3), (3, 10)):
        # Start each case from an empty collection so that the set of
        # stored _ids can be compared exactly with the ids supplied.
        collection.delete_many({})
        ids = [str(i) for i in range(n_docs)]
        result_ids = vectorstore.add_documents(
            make_docs(n_docs), ids, batch_size=batch_size
        )
        assert len(result_ids) == n_docs
        assert set(ids) == set(collection.distinct("_id"))
|
||||||
|
|
||||||
def test_index_creation(
    self, embeddings: Embeddings, index_collection: Any
) -> None:
    """Create a 1536-dimension vector search index on ``index_collection``.

    The test makes no assertions of its own; it passes as long as
    building the store and requesting the index do not raise.
    """
    store = PatchedMongoDBAtlasVectorSearch(
        collection=index_collection,
        embedding=embeddings,
        index_name=INDEX_CREATION_NAME,
    )
    store.create_vector_search_index(dimensions=1536)
|
||||||
|
|
||||||
def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None:
    """Create a vector search index, then request it again with ``update=True``.

    Both calls use 1536 dimensions. The test makes no assertions of its
    own; it passes as long as neither call raises.
    """
    dims = 1536
    store = PatchedMongoDBAtlasVectorSearch(
        collection=index_collection,
        embedding=embeddings,
        index_name=INDEX_CREATION_NAME,
    )
    # First call creates the index; the second repeats the request with the
    # update flag set.
    store.create_vector_search_index(dimensions=dims)
    store.create_vector_search_index(dimensions=dims, update=True)
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.embeddings import Embeddings
|
|||||||
from pymongo.collection import Collection
|
from pymongo.collection import Collection
|
||||||
|
|
||||||
from langchain_mongodb import MongoDBAtlasVectorSearch
|
from langchain_mongodb import MongoDBAtlasVectorSearch
|
||||||
|
from langchain_mongodb.utils import str_to_oid
|
||||||
from tests.utils import ConsistentFakeEmbeddings, MockCollection
|
from tests.utils import ConsistentFakeEmbeddings, MockCollection
|
||||||
|
|
||||||
INDEX_NAME = "langchain-test-index"
|
INDEX_NAME = "langchain-test-index"
|
||||||
@ -81,7 +82,7 @@ class TestMongoDBAtlasVectorSearch:
|
|||||||
assert loads(dumps(output[0].page_content)) == output[0].page_content
|
assert loads(dumps(output[0].page_content)) == output[0].page_content
|
||||||
assert loads(dumps(output[0].metadata)) == output[0].metadata
|
assert loads(dumps(output[0].metadata)) == output[0].metadata
|
||||||
json_metadata = dumps(output[0].metadata) # normal json.dumps
|
json_metadata = dumps(output[0].metadata) # normal json.dumps
|
||||||
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
|
isinstance(str_to_oid(json_util.loads(json_metadata)["_id"]), ObjectId)
|
||||||
|
|
||||||
def test_from_documents(
|
def test_from_documents(
|
||||||
self, embedding_openai: Embeddings, collection: MockCollection
|
self, embedding_openai: Embeddings, collection: MockCollection
|
||||||
|
Loading…
Reference in New Issue
Block a user