diff --git a/libs/partners/mongodb/langchain_mongodb/index.py b/libs/partners/mongodb/langchain_mongodb/index.py index a2c0e10ed66..d0f99511fd7 100644 --- a/libs/partners/mongodb/langchain_mongodb/index.py +++ b/libs/partners/mongodb/langchain_mongodb/index.py @@ -1,11 +1,25 @@ import logging -from typing import Any, Dict, List, Optional +from time import monotonic, sleep +from typing import Any, Callable, Dict, List, Optional from pymongo.collection import Collection +from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel logger = logging.getLogger(__file__) +_DELAY = 0.5 # Interval between checks for index operations + + +def _search_index_error_message() -> str: + return ( + "Search index operations are not currently available on shared clusters, " + "such as MO. They require dedicated clusters >= M10. " + "You may still perform vector search. " + "You simply must set up indexes manually. Follow the instructions here: " + "https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/" + ) + def _vector_search_index_definition( dimensions: int, @@ -32,7 +46,9 @@ def create_vector_search_index( dimensions: int, path: str, similarity: str, - filters: List[Dict[str, str]], + filters: Optional[List[Dict[str, str]]] = None, + *, + wait_until_complete: Optional[float] = None, ) -> None: """Experimental Utility function to create a vector search index @@ -43,31 +59,65 @@ def create_vector_search_index( path (str): field with vector embedding similarity (str): The similarity score used for the index filters (List[Dict[str, str]]): additional filters for index definition. + wait_until_complete (Optional[float]): If provided, number of seconds to wait + until search index is ready. 
""" logger.info("Creating Search Index %s on %s", index_name, collection.name) - result = collection.create_search_index( - SearchIndexModel( - definition=_vector_search_index_definition( - dimensions=dimensions, path=path, similarity=similarity, filters=filters - ), - name=index_name, - type="vectorSearch", + + try: + result = collection.create_search_index( + SearchIndexModel( + definition=_vector_search_index_definition( + dimensions=dimensions, + path=path, + similarity=similarity, + filters=filters, + ), + name=index_name, + type="vectorSearch", + ) + ) + except OperationFailure as e: + raise OperationFailure(_search_index_error_message()) from e + + if wait_until_complete: + _wait_for_predicate( + predicate=lambda: _is_index_ready(collection, index_name), + err=f"Index {index_name} creation did not finish in {wait_until_complete}!", + timeout=wait_until_complete, ) - ) logger.info(result) -def drop_vector_search_index(collection: Collection, index_name: str) -> None: +def drop_vector_search_index( + collection: Collection, + index_name: str, + *, + wait_until_complete: Optional[float] = None, +) -> None: """Drop a created vector search index Args: collection (Collection): MongoDB Collection with index to be dropped index_name (str): Name of the MongoDB index + wait_until_complete (Optional[float]): If provided, number of seconds to wait + until the collection lists no search indexes.
""" logger.info( "Dropping Search Index %s from Collection: %s", index_name, collection.name ) - collection.drop_search_index(index_name) + try: + collection.drop_search_index(index_name) + except OperationFailure as e: + if "CommandNotSupported" in str(e): + raise OperationFailure(_search_index_error_message()) from e + # else this most likely means an ongoing drop request was made so skip + if wait_until_complete: + _wait_for_predicate( + predicate=lambda: len(list(collection.list_search_indexes())) == 0, + err=f"Index {index_name} did not drop in {wait_until_complete}!", + timeout=wait_until_complete, + ) logger.info("Vector Search index %s.%s dropped", collection.name, index_name) @@ -78,8 +128,12 @@ def update_vector_search_index( path: str, similarity: str, filters: List[Dict[str, str]], + *, + wait_until_complete: Optional[float] = None, ) -> None: - """Leverages the updateSearchIndex call + """Update a search index. + + Replace the existing index definition with the provided definition. Args: collection (Collection): MongoDB Collection @@ -88,18 +142,73 @@ def update_vector_search_index( path (str): field with vector embedding. similarity (str): The similarity score used for the index. filters (List[Dict[str, str]]): additional filters for index definition. + wait_until_complete (Optional[float]): If provided, number of seconds to wait + until search index is ready. 
""" logger.info( "Updating Search Index %s from Collection: %s", index_name, collection.name ) - collection.update_search_index( - name=index_name, - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - ), - ) + try: + collection.update_search_index( + name=index_name, + definition=_vector_search_index_definition( + dimensions=dimensions, + path=path, + similarity=similarity, + filters=filters, + ), + ) + except OperationFailure as e: + raise OperationFailure(_search_index_error_message()) from e + + if wait_until_complete: + _wait_for_predicate( + predicate=lambda: _is_index_ready(collection, index_name), + err=f"Index {index_name} update did not complete in {wait_until_complete}!", + timeout=wait_until_complete, + ) logger.info("Update succeeded") + + +def _is_index_ready(collection: Collection, index_name: str) -> bool: + """Check for the index name in the list of available search indexes to see if the + specified index is of status READY + + Args: + collection (Collection): MongoDB Collection to check for search indexes + index_name (str): Vector Search Index name + + Returns: + bool: True if the index is present and READY, False otherwise + """ + try: + search_indexes = collection.list_search_indexes(index_name) + except OperationFailure as e: + raise OperationFailure(_search_index_error_message()) from e + + for index in search_indexes: + if index["type"] == "vectorSearch" and index["status"] == "READY": + return True + return False + + +def _wait_for_predicate( + predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5 +) -> None: + """Generic helper that blocks until the predicate returns true + + Args: + predicate (Callable[[], bool]): A function that returns a boolean value + err (str): Error message to raise if the timeout elapses + timeout (float, optional): wait time for predicate. Defaults to 120. + interval (float, optional): Interval to check predicate. Defaults to 0.5.
+ + Raises: + TimeoutError: If the predicate is still false after timeout seconds + """ + start = monotonic() + while not predicate(): + if monotonic() - start > timeout: + raise TimeoutError(err) + sleep(interval) diff --git a/libs/partners/mongodb/langchain_mongodb/vectorstores.py b/libs/partners/mongodb/langchain_mongodb/vectorstores.py index c85bce58b85..d177f1653f5 100644 --- a/libs/partners/mongodb/langchain_mongodb/vectorstores.py +++ b/libs/partners/mongodb/langchain_mongodb/vectorstores.py @@ -629,4 +629,4 @@ class MongoDBAtlasVectorSearch(VectorStore): path=self._embedding_key, similarity=self._relevance_score_fn, filters=filters or [], - ) + ) # type: ignore [operator] diff --git a/libs/partners/mongodb/tests/integration_tests/test_index.py b/libs/partners/mongodb/tests/integration_tests/test_index.py new file mode 100644 index 00000000000..dd1697bfa22 --- /dev/null +++ b/libs/partners/mongodb/tests/integration_tests/test_index.py @@ -0,0 +1,73 @@ +"""Search index commands are only supported on Atlas Clusters >=M10""" + +import os + +import pytest +from pymongo import MongoClient +from pymongo.collection import Collection + +from langchain_mongodb import index + + +@pytest.fixture +def collection() -> Collection: + """Depending on uri, this could point to any type of cluster.""" + uri = os.environ.get("MONGODB_ATLAS_URI") + client: MongoClient = MongoClient(uri) + clxn = client["db"].create_collection("collection") + return clxn + + +def test_search_index_commands(collection: Collection) -> None: + index_name = "vector_index" + dimensions = 1536 + path = "embedding" + similarity = "cosine" + filters: list = [] + wait_until_complete = 120 + + for index_info in collection.list_search_indexes(): + index.drop_vector_search_index( + collection, index_info["name"], wait_until_complete=wait_until_complete + ) + + assert len(list(collection.list_search_indexes())) == 0 + + index.create_vector_search_index( + collection=collection, + index_name=index_name, + dimensions=dimensions, + path=path,
similarity=similarity, + filters=filters, + wait_until_complete=wait_until_complete, + ) + + assert index._is_index_ready(collection, index_name) + indexes = list(collection.list_search_indexes()) + assert len(indexes) == 1 + assert indexes[0]["name"] == index_name + + new_similarity = "euclidean" + index.update_vector_search_index( + collection, + index_name, + 1536, + "embedding", + new_similarity, + [], + wait_until_complete=wait_until_complete, + ) + + assert index._is_index_ready(collection, index_name) + indexes = list(collection.list_search_indexes()) + assert len(indexes) == 1 + assert indexes[0]["name"] == index_name + assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity + + index.drop_vector_search_index( + collection, index_name, wait_until_complete=wait_until_complete + ) + + indexes = list(collection.list_search_indexes()) + assert len(indexes) == 0 diff --git a/libs/partners/mongodb/tests/unit_tests/test_index.py b/libs/partners/mongodb/tests/unit_tests/test_index.py new file mode 100644 index 00000000000..61776cf55d2 --- /dev/null +++ b/libs/partners/mongodb/tests/unit_tests/test_index.py @@ -0,0 +1,55 @@ +"""Search index commands are only supported on Atlas Clusters >=M10""" + +import os +from time import sleep + +import pytest +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.errors import OperationFailure, ServerSelectionTimeoutError + +from langchain_mongodb import index + + +@pytest.fixture +def collection() -> Collection: + """Depending on uri, this could point to any type of cluster. 
+ + For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster None: + with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): + index.create_vector_search_index( + collection, "index_name", 1536, "embedding", "cosine", [] + ) + + +def test_drop_vector_search_index(collection: Collection) -> None: + with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): + index.drop_vector_search_index(collection, "index_name") + + +def test_update_vector_search_index(collection: Collection) -> None: + with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): + index.update_vector_search_index( + collection, "index_name", 1536, "embedding", "cosine", [] + ) + + +def test___is_index_ready(collection: Collection) -> None: + with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): + index._is_index_ready(collection, "index_name") + + +def test__wait_for_predicate() -> None: + err = "error string" + with pytest.raises(TimeoutError) as e: + index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1) + assert err in str(e) + + index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)